From d0f3dac209e7908303e7041500f39d636b51de30 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sat, 23 Nov 2024 23:36:55 +0800 Subject: [PATCH 1/8] [Model] Add support for OLMo architecture --- .../mlc_llm/conversation_template/__init__.py | 1 + python/mlc_llm/conversation_template/olmo.py | 26 + python/mlc_llm/interface/gen_config.py | 1 + python/mlc_llm/model/model.py | 18 + python/mlc_llm/model/olmo/__init__.py | 0 python/mlc_llm/model/olmo/olmo_loader.py | 172 ++++++ python/mlc_llm/model/olmo/olmo_model.py | 568 ++++++++++++++++++ .../mlc_llm/model/olmo/olmo_quantization.py | 94 +++ 8 files changed, 880 insertions(+) create mode 100644 python/mlc_llm/conversation_template/olmo.py create mode 100644 python/mlc_llm/model/olmo/__init__.py create mode 100644 python/mlc_llm/model/olmo/olmo_loader.py create mode 100644 python/mlc_llm/model/olmo/olmo_model.py create mode 100644 python/mlc_llm/model/olmo/olmo_quantization.py diff --git a/python/mlc_llm/conversation_template/__init__.py b/python/mlc_llm/conversation_template/__init__.py index 9873062ac7..779ef2509e 100644 --- a/python/mlc_llm/conversation_template/__init__.py +++ b/python/mlc_llm/conversation_template/__init__.py @@ -28,5 +28,6 @@ stablelm, tinyllama, wizardlm, + olmo, ) from .registry import ConvTemplateRegistry diff --git a/python/mlc_llm/conversation_template/olmo.py b/python/mlc_llm/conversation_template/olmo.py new file mode 100644 index 0000000000..89e3bc42d9 --- /dev/null +++ b/python/mlc_llm/conversation_template/olmo.py @@ -0,0 +1,26 @@ +"""OLMo default templates""" + +from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders + +from .registry import ConvTemplateRegistry + +# Note that eos_token id is "50279" both in Allenai and AMD version. +# So use the number instead of text. +# Allenai version chat_template and eos_token: https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json +# AMD version chat_template and eos_token: https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json +ConvTemplateRegistry.register_conv_template( + Conversation( + name="olmo", + system_template=f"{MessagePlaceholders.SYSTEM.value}", + system_message="", + system_prefix_token_ids=[50279], + roles={ + "user": "<|user|>", + "assistant": "<|assistant|>", + }, + seps=["\n"], + role_content_sep="\n", + role_empty_sep="\n", + stop_token_ids=[50279], + ) +) diff --git a/python/mlc_llm/interface/gen_config.py b/python/mlc_llm/interface/gen_config.py index 62d16100c9..38cc91de21 100644 --- a/python/mlc_llm/interface/gen_config.py +++ b/python/mlc_llm/interface/gen_config.py @@ -306,4 +306,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b "aya-23", "deepseek_v2", "deepseek", + "olmo", } diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index f5b05763ed..07fa8b2c57 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -39,6 +39,7 @@ from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization from .starcoder2 import starcoder2_loader, starcoder2_model, starcoder2_quantization +from .olmo import olmo_loader, olmo_model, olmo_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. 
It is required to have @@ -532,4 +533,21 @@ class Model: "ft-quant": deepseek_quantization.ft_quant, }, ), + "olmo": Model( + name="olmo", + model=olmo_model.OLMoForCausalLM, + config=olmo_model.OLMoConfig, + source={ + "huggingface-torch": olmo_loader.huggingface, + "huggingface-safetensor": olmo_loader.huggingface, + "awq": olmo_loader.awq, + }, + quantize={ + "no-quant": olmo_quantization.no_quant, + "group-quant": olmo_quantization.group_quant, + "ft-quant": olmo_quantization.ft_quant, + "awq": olmo_quantization.awq_quant, + "per-tensor-quant": olmo_quantization.per_tensor_quant, + }, + ), } diff --git a/python/mlc_llm/model/olmo/__init__.py b/python/mlc_llm/model/olmo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/mlc_llm/model/olmo/olmo_loader.py b/python/mlc_llm/model/olmo/olmo_loader.py new file mode 100644 index 0000000000..bba086c6c7 --- /dev/null +++ b/python/mlc_llm/model/olmo/olmo_loader.py @@ -0,0 +1,172 @@ +""" +This file specifies how MLC's OLMo parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" + +import functools + +import numpy as np + +from mlc_llm.loader import ExternMapping +from mlc_llm.quantization import Quantization + +from .olmo_model import OLMoConfig, OLMoForCausalLM +from .olmo_quantization import awq_quant + + +def huggingface(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : OLMoConfig + The configuration of the OLMo model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = OLMoForCausalLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping + + +def awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. 
+ Parameters + ---------- + model_config : OLMoConfig + The configuration of the OLMo model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. + """ + model, _ = awq_quant(model_config, quantization) + _, _named_params, _ = model.export_tvm( # type: ignore[misc] + spec=model.get_default_spec(), # type: ignore[attr-defined] + allow_extern=True, + ) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate( + [q, k, v], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate( + [gate, up], + axis=1, # AWQ GEMM would transpose the weight + ).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping \ No newline at end of file diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py new file mode 100644 index 0000000000..f752f6a44b --- /dev/null +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -0,0 +1,568 @@ +""" +Implementation for OLMo architecture. 
+TODO: add docstring +""" + +import dataclasses +from functools import partial +from typing import Any, Dict, Optional + +from tvm import te, tir +from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import Tensor, op + +from mlc_llm import op as op_ext +from mlc_llm.nn import PagedKVCache, RopeMode +from mlc_llm.support import logging +from mlc_llm.support import tensor_parallel as tp +from mlc_llm.support.config import ConfigBase +from mlc_llm.support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class OLMoConfig(ConfigBase): + """Configuration of the OLMo model.""" + + vocab_size: int = None + hidden_size: int = None + num_attention_heads: int = None + num_key_value_heads: int = 0 + head_dim: int = 0 + position_embedding_base: int = 0 + rope_scaling: Optional[Dict[str, Any]] = None + intermediate_size: int = None + hidden_act: str = None + num_hidden_layers: int = None + tie_word_embeddings: bool = False + context_window_size: int = 0 + prefill_chunk_size: int = 0 + tensor_parallel_shards: int = 1 + pipeline_parallel_stages: int = 1 + max_batch_size: int = 1 + clip_qkv: float = None + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 + + if self.context_window_size == 0: + for name in ["max_position_embeddings", "max_sequence_length"]: + if name in self.kwargs: + self.context_window_size = self.kwargs.pop(name) + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("context_window_size"), + bold(name), + self.context_window_size, + ) + break + else: + raise ValueError( + "Unable to determine the maximum sequence length, because none of " + "`context_window_size`, `max_position_embeddings` or `max_sequence_length` is " + "provided in `config.json`." + ) + + if self.prefill_chunk_size == 0: + logger.info( + "%s defaults to %d", + bold("prefill_chunk_size"), + min(self.context_window_size, 8192) + ) + self.prefill_chunk_size = min(self.context_window_size, 8192) + elif self.prefill_chunk_size > self.context_window_size: + logger.info( + "Overriding %s from %d to %d", + bold("prefill_chunk_size"), + self.prefill_chunk_size, + min(self.context_window_size, 8192), + ) + self.prefill_chunk_size = min(self.context_window_size, 8192) + + if ( + self.pipeline_parallel_stages <= 0 + or self.pipeline_parallel_stages > self.num_hidden_layers + ): + raise ValueError( + f'Invalid "pipeline_parallel_stages" value({self.pipeline_parallel_stages}). ' + ) + + if self.clip_qkv is not None: + if self.clip_qkv <= 0: + raise ValueError( + f"'clip_qkv'({self.clip_qkv}) should be non-negative" + ) + + +class OLMoEebedding(nn.Embedding): + """The embedding module that can be shared with the final lm_head. From Qwen2Embedding.""" + + def lm_head_forward(self, x: nn.Tensor): + """The lm_head forwarding, which transposes the weight and multiplies + with the input tensor. 
+ """ + weight = nn.op.permute_dims(self.weight) + return nn.op.matmul(x, weight, out_dtype="float32") + + +class OLMoAttention(nn.Module): + def __init__(self, config: OLMoConfig): + self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards + assert ( + config.num_key_value_heads >= config.tensor_parallel_shards + ), f"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}" + assert ( + config.num_key_value_heads % config.tensor_parallel_shards == 0 + ), f"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards" + self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards + self.head_dim = config.head_dim + self.qkv_proj = nn.Linear( + in_features=config.hidden_size, + out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim, + bias=False, + ) + self.clip_qkv = config.clip_qkv + self.o_proj = nn.Linear( + in_features=self.num_q_heads * self.head_dim, + out_features=config.hidden_size, + bias=False + ) + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads + b, s, _ = hidden_states.shape + + # QKV Projection + qkv = self.qkv_proj(hidden_states) + + # Clamp after qkv projection if needed + dtype = hidden_states.dtype + if self.clip_qkv is not None: + qkv = nn.maximum(qkv, nn.Tensor.from_scalar(-self.clip_qkv, dtype=dtype)) + qkv = nn.minimum(qkv, nn.Tensor.from_scalar(self.clip_qkv, dtype=dtype)) + + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) + # Attention + output = op.reshape( + paged_kv_cache.attention_with_fused_qkv(layer_id, qkv, self.num_q_heads), + (b, s, h_q * d), + ) + return self.o_proj(output) + + +# Copied from qwen2_model.ACT2FN +ACT2FN = { + "gelu": partial(nn.gelu, approximate=False), + "relu": nn.relu, + "silu": nn.silu, + "swish": nn.silu, + "gelu_new": partial(nn.gelu, approximate=True), +} + + +class OLMoFFN(nn.Module): + def __init__(self, config: OLMoConfig): + super().__init__() + if config.intermediate_size % config.tensor_parallel_shards != 0: + raise ValueError( + f"Cannot split MLP intermediate size {config.intermediate_size} " + f"evenly to {config.tensor_parallel_shards} GPUs." 
+ ) + self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards + self.gate_up_proj = nn.Linear( + in_features=config.hidden_size, + out_features=2 * self.intermediate_size, + bias=False, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.down_proj = nn.Linear( + in_features=self.intermediate_size, + out_features=config.hidden_size, + bias=False, + ) + + def forward(self, x: Tensor): + concat_x1_x2 = self.gate_up_proj(x) + x1, x2 = op.split(concat_x1_x2, 2, axis=-1) + return self.down_proj(self.act_fn(x1) * x2) + + +class OLMoDecoderLayer(nn.Module): + def __init__(self, config: OLMoConfig): + self.input_layernorm = nn.LayerNorm( + normalized_shape = config.hidden_size, + eps= 1e-5, + elementwise_affine=False, + ) + self.self_attn = OLMoAttention(config) + self.post_attention_layernorm = nn.LayerNorm( + normalized_shape = config.hidden_size, + eps= 1e-5, + elementwise_affine=False, + ) + self.mlp = OLMoFFN(config) + + def _set_tp(): + def _set(layer, hint): + layer.weight.attrs["shard_strategy"] = hint + + hd = config.head_dim + q = self.self_attn.num_q_heads * hd + k = self.self_attn.num_kv_heads * hd + v = self.self_attn.num_kv_heads * hd + i = self.mlp.intermediate_size + _set(self.self_attn.qkv_proj, tp.ShardSingleDim("_shard_qkv", segs=[q, k, v], dim=0)) + _set(self.self_attn.o_proj, tp.ShardSingleDim("_shard_o", dim=1)) + _set(self.mlp.gate_up_proj, tp.ShardSingleDim("_shard_mlp_up", segs=[i, i], dim=0)) + _set(self.mlp.down_proj, tp.ShardSingleDim("_shard_mlp_down", dim=1)) + + self.tensor_parallel_shards = config.tensor_parallel_shards + _set_tp() + + def _apply_residual(self, out, residual): + if self.tensor_parallel_shards > 1: + return op.ccl_allreduce(out, "sum") + residual + return out + residual + + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) + hidden_states = self._apply_residual(out, residual=hidden_states) + out = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = self._apply_residual(out, residual=hidden_states) + return hidden_states + + +class OLMoModel(nn.Module): + def __init__(self, config: OLMoConfig): + assert config.hidden_size % config.num_attention_heads == 0 + self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList( + [OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = nn.LayerNorm( + normalized_shape = config.hidden_size, + eps= 1e-5, + elementwise_affine=False, + ) + + self.num_layers_per_stage = ( + config.num_hidden_layers + config.pipeline_parallel_stages - 1 + ) // config.pipeline_parallel_stages + # Compute pipeline layer partition. 
+ layers_per_stage = ( + config.num_hidden_layers + config.pipeline_parallel_stages - 1 + ) // config.pipeline_parallel_stages + self.layer_partition = [ + i * layers_per_stage for i in range(config.pipeline_parallel_stages) + ] + [config.num_hidden_layers] + + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + hidden_states = inputs + for layer_id, layer in enumerate(self.layers): + if layer_id != 0 and layer_id in self.layer_partition: + hidden_states = op_ext.pipeline_stage_boundary(hidden_states) + hidden_states = layer(hidden_states, paged_kv_cache, layer_id) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OLMoForCausalLM(nn.Module): + def __init__(self, config: OLMoConfig): + self.model = OLMoModel(config) + self.tie_word_embeddings = config.tie_word_embeddings + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.rope_theta = config.position_embedding_base + self.rope_scaling = config.rope_scaling + self.intermediate_size = config.intermediate_size + self.num_hidden_layers = config.num_hidden_layers + self.tensor_parallel_shards = config.tensor_parallel_shards + self.dtype = "float32" + + def _set_pp(): + # hidden layers + for layer_id in range(config.num_hidden_layers): + stage = layer_id // (config.num_hidden_layers // config.pipeline_parallel_stages) + for _, param in self.model.layers[layer_id].named_parameters(): + param.attrs["pipeline_stages"] = [stage] + + # embedding table and lm_head is required by all stages + all_stages = list(range(config.pipeline_parallel_stages)) + self.model.embed_tokens.weight.attrs["pipeline_stages"] = all_stages + if not config.tie_word_embeddings: + self.lm_head.weight.attrs["pipeline_stages"] = all_stages + + _set_pp() + + def to(self, dtype: Optional[str] = None): + super().to(dtype=dtype) + if dtype is not None: + self.dtype = dtype + + def batch_forward( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + logit_positions: Optional[Tensor] = None + ): + op_ext.configure() + + hidden_states = self.model(input_embeds, paged_kv_cache) + if logit_positions is not None: + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = op.take(hidden_states, logit_positions, axis=1) + return self.get_logits(hidden_states) + + def batch_forward_to_last_hidden_states( + self, + input_embeds: Tensor, + paged_kv_cache: PagedKVCache, + ): + op_ext.configure() + hidden_states = self.model(input_embeds, paged_kv_cache) + return hidden_states + + def embed(self, input_ids: Tensor): + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + return self.model.embed_tokens(input_ids) + + def get_logits(self, hidden_states: Tensor): + op_ext.configure() + if self.tie_word_embeddings: + logits = self.model.embed_tokens.lm_head_forward(hidden_states) + else: + logits = self.lm_head(hidden_states) + if logits.dtype != "float32": + logits = logits.astype("float32") + return logits + + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + op_ext.configure() + if self.tensor_parallel_shards > 1: + logit_positions = op.ccl_broadcast_from_worker0(logit_positions) + hidden_states = 
op.take(hidden_states, logit_positions, axis=0) + return hidden_states + + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + + def _index(x: te.Tensor): # get tensor of the last sequence + b, s, d = x.shape + return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k]) + + hidden_states = self.model(input_embed, paged_kv_cache) + hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + logits = self.get_logits(hidden_states) + return logits, paged_kv_cache + + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + hidden_states = self.model(input_embed, paged_kv_cache) + logits = self.get_logits(hidden_states) + return logits, paged_kv_cache + + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + op_ext.configure() + hidden_states = self.model(input_embed, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_prefill( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) + return logits, paged_kv_cache + + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.batch_forward(input_embeds, paged_kv_cache) + return logits, paged_kv_cache + + def batch_prefill_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_decode_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def batch_verify_to_last_hidden_states( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): + hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) + return hidden_states, paged_kv_cache + + def create_paged_kv_cache( + self, + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + prefill_chunk_size: tir.Var, + page_size: tir.Var, + support_sliding_window: tir.Var + ) -> PagedKVCache: + return PagedKVCache.create_generic( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + prefill_chunk_size=prefill_chunk_size, + page_size=page_size, + support_sliding_window=support_sliding_window, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards, + num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards, + head_dim=self.head_dim, + rope_mode=RopeMode.NORMAL, + rope_scale=1, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + layer_partition=self.model.layer_partition, + dtype=self.dtype, + ) + + def get_default_spec(self): + mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor(["seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "get_logits": { + "hidden_states": nn.spec.Tensor(["seq_len", 
self.hidden_size], self.dtype), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_select_last_hidden_states": { + "hidden_states": nn.spec.Tensor(["seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "prefill": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "prefill_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_to_last_hidden_states": { + "input_embed": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_prefill_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_decode_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "batch_verify_to_last_hidden_states": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "prefill_chunk_size": int, + "page_size": int, + "support_sliding_window": int, + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + } + return nn.spec.ModuleSpec.from_raw(mod_spec, self) \ No newline at end of file diff --git a/python/mlc_llm/model/olmo/olmo_quantization.py b/python/mlc_llm/model/olmo/olmo_quantization.py new file mode 100644 index 0000000000..01fd3f4cd9 --- /dev/null +++ b/python/mlc_llm/model/olmo/olmo_quantization.py @@ -0,0 +1,94 @@ +"""This file specifies how MLC's OLMo parameters are quantized using group quantization +or other formats.""" + +from typing import Tuple + +from tvm.relax.frontend import nn + +from 
mlc_llm.loader import QuantizeMapping +from mlc_llm.quantization import ( + AWQQuantize, + FTQuantize, + GroupQuantize, + NoQuantize, + PerTensorQuantize, +) + +from .olmo_model import OLMoConfig, OLMoForCausalLM + + +def group_quant( + model_config: OLMoConfig, + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a OLMo-architecture model using group quantization.""" + model: nn.Module = OLMoForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + quantization.tensor_parallel_shards = model_config.tensor_parallel_shards + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def ft_quant( + model_config: OLMoConfig, + quantization: FTQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a OLMo-architecture model using FasterTransformer quantization.""" + model: nn.Module = OLMoForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: OLMoConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a OLMo-architecture model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = OLMoForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def no_quant( + model_config: OLMoConfig, + quantization: NoQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a OLMo model without quantization.""" + model: nn.Module = OLMoForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + return model, quant_map + + +def per_tensor_quant( + model_config: OLMoConfig, + quantization: PerTensorQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a OLMo-architecture model using per-tensor quantization.""" + model: nn.Module = OLMoForCausalLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + tensor_parallel_shards=model_config.tensor_parallel_shards, + ) + return model, quant_map \ No newline at end of file From 538caa3ad0cb705c02680295fb1d7af6d194fd9d Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 10:51:17 +0800 Subject: [PATCH 2/8] Update the 'blacked' files --- python/mlc_llm/model/olmo/olmo_loader.py | 2 +- python/mlc_llm/model/olmo/olmo_model.py | 66 +++++++++---------- .../mlc_llm/model/olmo/olmo_quantization.py | 2 +- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/python/mlc_llm/model/olmo/olmo_loader.py b/python/mlc_llm/model/olmo/olmo_loader.py index bba086c6c7..3d1a7c037c 100644 --- a/python/mlc_llm/model/olmo/olmo_loader.py +++ b/python/mlc_llm/model/olmo/olmo_loader.py @@ -169,4 +169,4 @@ def awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping: [mlc_name], functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), ) - return mapping \ No newline at end of file + return mapping diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index f752f6a44b..1d8d22b469 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -56,7 +56,7 @@ def __post_init__(self): 
self.position_embedding_base = self.kwargs.pop("rope_theta") else: self.position_embedding_base = 10000 - + if self.context_window_size == 0: for name in ["max_position_embeddings", "max_sequence_length"]: if name in self.kwargs: @@ -77,9 +77,7 @@ def __post_init__(self): if self.prefill_chunk_size == 0: logger.info( - "%s defaults to %d", - bold("prefill_chunk_size"), - min(self.context_window_size, 8192) + "%s defaults to %d", bold("prefill_chunk_size"), min(self.context_window_size, 8192) ) self.prefill_chunk_size = min(self.context_window_size, 8192) elif self.prefill_chunk_size > self.context_window_size: @@ -98,12 +96,10 @@ def __post_init__(self): raise ValueError( f'Invalid "pipeline_parallel_stages" value({self.pipeline_parallel_stages}). ' ) - + if self.clip_qkv is not None: if self.clip_qkv <= 0: - raise ValueError( - f"'clip_qkv'({self.clip_qkv}) should be non-negative" - ) + raise ValueError(f"'clip_qkv'({self.clip_qkv}) should be non-negative") class OLMoEebedding(nn.Embedding): @@ -135,11 +131,11 @@ def __init__(self, config: OLMoConfig): ) self.clip_qkv = config.clip_qkv self.o_proj = nn.Linear( - in_features=self.num_q_heads * self.head_dim, - out_features=config.hidden_size, - bias=False + in_features=self.num_q_heads * self.head_dim, + out_features=config.hidden_size, + bias=False, ) - + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -148,11 +144,9 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: qkv = self.qkv_proj(hidden_states) # Clamp after qkv projection if needed - dtype = hidden_states.dtype if self.clip_qkv is not None: - qkv = nn.maximum(qkv, nn.Tensor.from_scalar(-self.clip_qkv, dtype=dtype)) - qkv = nn.minimum(qkv, nn.Tensor.from_scalar(self.clip_qkv, dtype=dtype)) - + qkv = qkv.maximum(-self.clip_qkv).minimum(self.clip_qkv) + qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) # Attention output = op.reshape( @@ -160,7 +154,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: (b, s, h_q * d), ) return self.o_proj(output) - + # Copied from qwen2_model.ACT2FN ACT2FN = { @@ -197,27 +191,27 @@ def forward(self, x: Tensor): concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) return self.down_proj(self.act_fn(x1) * x2) - + class OLMoDecoderLayer(nn.Module): def __init__(self, config: OLMoConfig): self.input_layernorm = nn.LayerNorm( - normalized_shape = config.hidden_size, - eps= 1e-5, + normalized_shape=config.hidden_size, + eps=1e-5, elementwise_affine=False, ) self.self_attn = OLMoAttention(config) self.post_attention_layernorm = nn.LayerNorm( - normalized_shape = config.hidden_size, - eps= 1e-5, + normalized_shape=config.hidden_size, + eps=1e-5, elementwise_affine=False, ) self.mlp = OLMoFFN(config) - + def _set_tp(): def _set(layer, hint): layer.weight.attrs["shard_strategy"] = hint - + hd = config.head_dim q = self.self_attn.num_q_heads * hd k = self.self_attn.num_kv_heads * hd @@ -252,11 +246,11 @@ def __init__(self, config: OLMoConfig): [OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.norm = nn.LayerNorm( - normalized_shape = config.hidden_size, - eps= 1e-5, + normalized_shape=config.hidden_size, + eps=1e-5, elementwise_affine=False, ) - + self.num_layers_per_stage = ( config.num_hidden_layers + config.pipeline_parallel_stages - 1 ) // config.pipeline_parallel_stages @@ -267,7 +261,7 @@ def 
__init__(self, config: OLMoConfig): self.layer_partition = [ i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = inputs for layer_id, layer in enumerate(self.layers): @@ -302,7 +296,7 @@ def _set_pp(): stage = layer_id // (config.num_hidden_layers // config.pipeline_parallel_stages) for _, param in self.model.layers[layer_id].named_parameters(): param.attrs["pipeline_stages"] = [stage] - + # embedding table and lm_head is required by all stages all_stages = list(range(config.pipeline_parallel_stages)) self.model.embed_tokens.weight.attrs["pipeline_stages"] = all_stages @@ -320,7 +314,7 @@ def batch_forward( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, - logit_positions: Optional[Tensor] = None + logit_positions: Optional[Tensor] = None, ): op_ext.configure() @@ -344,7 +338,7 @@ def embed(self, input_ids: Tensor): if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) - + def get_logits(self, hidden_states: Tensor): op_ext.configure() if self.tie_word_embeddings: @@ -365,7 +359,7 @@ def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() - def _index(x: te.Tensor): # get tensor of the last sequence + def _index(x: te.Tensor): # get tensor of the last sequence b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k]) @@ -415,7 +409,7 @@ def batch_decode_to_last_hidden_states( ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - + def batch_verify_to_last_hidden_states( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): @@ -428,7 +422,7 @@ def create_paged_kv_cache( max_total_seq_len: tir.Var, prefill_chunk_size: tir.Var, page_size: tir.Var, - support_sliding_window: tir.Var + support_sliding_window: tir.Var, ) -> PagedKVCache: return PagedKVCache.create_generic( max_batch_size=max_batch_size, @@ -447,7 +441,7 @@ def create_paged_kv_cache( layer_partition=self.model.layer_partition, dtype=self.dtype, ) - + def get_default_spec(self): mod_spec = { "embed": { @@ -565,4 +559,4 @@ def get_default_spec(self): }, }, } - return nn.spec.ModuleSpec.from_raw(mod_spec, self) \ No newline at end of file + return nn.spec.ModuleSpec.from_raw(mod_spec, self) diff --git a/python/mlc_llm/model/olmo/olmo_quantization.py b/python/mlc_llm/model/olmo/olmo_quantization.py index 01fd3f4cd9..16e4f1f47b 100644 --- a/python/mlc_llm/model/olmo/olmo_quantization.py +++ b/python/mlc_llm/model/olmo/olmo_quantization.py @@ -91,4 +91,4 @@ def per_tensor_quant( "", tensor_parallel_shards=model_config.tensor_parallel_shards, ) - return model, quant_map \ No newline at end of file + return model, quant_map From 08599b5164c9c9931668fb869b16a9a544d708d9 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 11:15:15 +0800 Subject: [PATCH 3/8] Complete isort and pylint --- .../mlc_llm/conversation_template/__init__.py | 2 +- python/mlc_llm/model/model.py | 2 +- python/mlc_llm/model/olmo/olmo_model.py | 56 +++++++++---------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/python/mlc_llm/conversation_template/__init__.py b/python/mlc_llm/conversation_template/__init__.py index 779ef2509e..5a679bf7b2 100644 --- 
a/python/mlc_llm/conversation_template/__init__.py +++ b/python/mlc_llm/conversation_template/__init__.py @@ -20,6 +20,7 @@ llava, mistral, oasst, + olmo, orion, phi, qwen2, @@ -28,6 +29,5 @@ stablelm, tinyllama, wizardlm, - olmo, ) from .registry import ConvTemplateRegistry diff --git a/python/mlc_llm/model/model.py b/python/mlc_llm/model/model.py index 07fa8b2c57..383ad600ba 100644 --- a/python/mlc_llm/model/model.py +++ b/python/mlc_llm/model/model.py @@ -28,6 +28,7 @@ from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization from .mistral import mistral_loader, mistral_model, mistral_quantization from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization +from .olmo import olmo_loader, olmo_model, olmo_quantization from .orion import orion_loader, orion_model, orion_quantization from .phi import phi_loader, phi_model, phi_quantization from .phi3 import phi3_loader, phi3_model, phi3_quantization @@ -39,7 +40,6 @@ from .rwkv6 import rwkv6_loader, rwkv6_model, rwkv6_quantization from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization from .starcoder2 import starcoder2_loader, starcoder2_model, starcoder2_quantization -from .olmo import olmo_loader, olmo_model, olmo_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. It is required to have diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 1d8d22b469..b51ddb7ff0 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -22,7 +22,7 @@ @dataclasses.dataclass -class OLMoConfig(ConfigBase): +class OLMoConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """Configuration of the OLMo model.""" vocab_size: int = None @@ -44,7 +44,7 @@ class OLMoConfig(ConfigBase): clip_qkv: float = None kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - def __post_init__(self): + def __post_init__(self): # pylint: disable=too-many-branches if self.num_key_value_heads == 0: self.num_key_value_heads = self.num_attention_heads if self.head_dim == 0: @@ -113,7 +113,7 @@ def lm_head_forward(self, x: nn.Tensor): return nn.op.matmul(x, weight, out_dtype="float32") -class OLMoAttention(nn.Module): +class OLMoAttention(nn.Module): # pylint: disable=missing-class-docstring def __init__(self, config: OLMoConfig): self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards assert ( @@ -136,7 +136,7 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -166,7 +166,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: } -class OLMoFFN(nn.Module): +class OLMoFFN(nn.Module): # pylint: disable=missing-class-docstring def __init__(self, config: OLMoConfig): super().__init__() if config.intermediate_size % config.tensor_parallel_shards != 0: @@ -187,13 +187,13 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, x: Tensor): + def forward(self, x: Tensor): # pylint: disable=missing-function-docstring concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) return self.down_proj(self.act_fn(x1) * x2) -class OLMoDecoderLayer(nn.Module): +class 
OLMoDecoderLayer(nn.Module): # pylint: disable=missing-class-docstring def __init__(self, config: OLMoConfig): self.input_layernorm = nn.LayerNorm( normalized_shape=config.hidden_size, @@ -230,7 +230,7 @@ def _apply_residual(self, out, residual): return op.ccl_allreduce(out, "sum") + residual return out + residual - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) @@ -238,7 +238,7 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: return hidden_states -class OLMoModel(nn.Module): +class OLMoModel(nn.Module): # pylint: disable=missing-class-docstring def __init__(self, config: OLMoConfig): assert config.hidden_size % config.num_attention_heads == 0 self.embed_tokens = OLMoEebedding(config.vocab_size, config.hidden_size) @@ -262,7 +262,7 @@ def __init__(self, config: OLMoConfig): i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring hidden_states = inputs for layer_id, layer in enumerate(self.layers): if layer_id != 0 and layer_id in self.layer_partition: @@ -272,7 +272,7 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): return hidden_states -class OLMoForCausalLM(nn.Module): +class OLMoForCausalLM(nn.Module): # pylint: disable=missing-class-docstring,too-many-instance-attributes def __init__(self, config: OLMoConfig): self.model = OLMoModel(config) self.tie_word_embeddings = config.tie_word_embeddings @@ -310,7 +310,7 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def batch_forward( + def batch_forward( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, @@ -325,7 +325,7 @@ def batch_forward( hidden_states = op.take(hidden_states, logit_positions, axis=1) return self.get_logits(hidden_states) - def batch_forward_to_last_hidden_states( + def batch_forward_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, @@ -334,12 +334,12 @@ def batch_forward_to_last_hidden_states( hidden_states = self.model(input_embeds, paged_kv_cache) return hidden_states - def embed(self, input_ids: Tensor): + def embed(self, input_ids: Tensor): # pylint: disable=missing-function-docstring if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) - def get_logits(self, hidden_states: Tensor): + def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function-docstring op_ext.configure() if self.tie_word_embeddings: logits = self.model.embed_tokens.lm_head_forward(hidden_states) @@ -349,14 +349,14 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): # pylint: disable=missing-function-docstring 
op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states - def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() def _index(x: te.Tensor): # get tensor of the last sequence @@ -368,55 +368,55 @@ def _index(x: te.Tensor): # get tensor of the last sequence logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_prefill( + def batch_prefill( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache - def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_prefill_to_last_hidden_states( + def batch_prefill_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_decode_to_last_hidden_states( + def batch_decode_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_verify_to_last_hidden_states( + def batch_verify_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def create_paged_kv_cache( + def create_paged_kv_cache( # pylint: 
disable=missing-function-docstring,too-many-arguments,too-many-positional-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, @@ -442,7 +442,7 @@ def create_paged_kv_cache( dtype=self.dtype, ) - def get_default_spec(self): + def get_default_spec(self): # pylint: disable=missing-function-docstring mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), From 358958299b9d93932672adff54f3ba3321ed9b34 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 11:55:47 +0800 Subject: [PATCH 4/8] Fix black pylink error --- python/mlc_llm/conversation_template/olmo.py | 6 ++- python/mlc_llm/model/olmo/olmo_model.py | 47 ++++++++++---------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/python/mlc_llm/conversation_template/olmo.py b/python/mlc_llm/conversation_template/olmo.py index 89e3bc42d9..17089dc16a 100644 --- a/python/mlc_llm/conversation_template/olmo.py +++ b/python/mlc_llm/conversation_template/olmo.py @@ -6,8 +6,10 @@ # Note that eos_token id is "50279" both in Allenai and AMD version. # So use the number instead of text. -# Allenai version chat_template and eos_token: https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json -# AMD version chat_template and eos_token: https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json +# Allenai version chat_template and eos_token: +# https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json +# AMD version chat_template and eos_token: +# https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json ConvTemplateRegistry.register_conv_template( Conversation( name="olmo", diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index b51ddb7ff0..3cbe6c1734 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -136,7 +136,7 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -187,7 +187,7 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, x: Tensor): # pylint: disable=missing-function-docstring + def forward(self, x: Tensor): concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) return self.down_proj(self.act_fn(x1) * x2) @@ -230,7 +230,7 @@ def _apply_residual(self, out, residual): return op.ccl_allreduce(out, "sum") + residual return out + residual - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) @@ -262,7 +262,7 @@ def __init__(self, config: OLMoConfig): i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): hidden_states = inputs for 
layer_id, layer in enumerate(self.layers): if layer_id != 0 and layer_id in self.layer_partition: @@ -272,7 +272,9 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disa return hidden_states -class OLMoForCausalLM(nn.Module): # pylint: disable=missing-class-docstring,too-many-instance-attributes +class OLMoForCausalLM( # pylint: disable=missing-class-docstring,too-many-instance-attributes + nn.Module +): def __init__(self, config: OLMoConfig): self.model = OLMoModel(config) self.tie_word_embeddings = config.tie_word_embeddings @@ -310,14 +312,13 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def batch_forward( # pylint: disable=missing-function-docstring + def batch_forward( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, logit_positions: Optional[Tensor] = None, ): op_ext.configure() - hidden_states = self.model(input_embeds, paged_kv_cache) if logit_positions is not None: if self.tensor_parallel_shards > 1: @@ -325,7 +326,7 @@ def batch_forward( # pylint: disable=missing-function-docstring hidden_states = op.take(hidden_states, logit_positions, axis=1) return self.get_logits(hidden_states) - def batch_forward_to_last_hidden_states( # pylint: disable=missing-function-docstring + def batch_forward_to_last_hidden_states( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, @@ -334,12 +335,12 @@ def batch_forward_to_last_hidden_states( # pylint: disable=missing-function-doc hidden_states = self.model(input_embeds, paged_kv_cache) return hidden_states - def embed(self, input_ids: Tensor): # pylint: disable=missing-function-docstring + def embed(self, input_ids: Tensor): if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) - def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function-docstring + def get_logits(self, hidden_states: Tensor): op_ext.configure() if self.tie_word_embeddings: logits = self.model.embed_tokens.lm_head_forward(hidden_states) @@ -349,14 +350,14 @@ def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function logits = logits.astype("float32") return logits - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): # pylint: disable=missing-function-docstring + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states - def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() def _index(x: te.Tensor): # get tensor of the last sequence @@ -368,55 +369,55 @@ def _index(x: te.Tensor): # get tensor of the last sequence logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def 
prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_prefill( # pylint: disable=missing-function-docstring + def batch_prefill( self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache - def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_prefill_to_last_hidden_states( # pylint: disable=missing-function-docstring + def batch_prefill_to_last_hidden_states( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_decode_to_last_hidden_states( # pylint: disable=missing-function-docstring + def batch_decode_to_last_hidden_states( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_verify_to_last_hidden_states( # pylint: disable=missing-function-docstring + def batch_verify_to_last_hidden_states( self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-many-arguments,too-many-positional-arguments + def create_paged_kv_cache( # pylint: disable=too-many-arguments,too-many-positional-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, @@ -442,7 +443,7 @@ def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-man dtype=self.dtype, ) - def get_default_spec(self): # pylint: disable=missing-function-docstring + def get_default_spec(self): mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), From 4adef05e8f92c8d7d93b3b79bf3fbeb528d536fb Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 12:01:35 +0800 Subject: [PATCH 5/8] Fix pylint and black error --- python/mlc_llm/model/olmo/olmo_model.py | 62 ++++++++++++++++--------- 1 file changed, 41 insertions(+), 21 deletions(-) diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 3cbe6c1734..322fb35ba3 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -136,7 +136,9 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, 
hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + def forward( + self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int + ): # pylint: disable=missing-function-docstring d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -187,7 +189,7 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, x: Tensor): + def forward(self, x: Tensor): # pylint: disable=missing-function-docstring concat_x1_x2 = self.gate_up_proj(x) x1, x2 = op.split(concat_x1_x2, 2, axis=-1) return self.down_proj(self.act_fn(x1) * x2) @@ -230,7 +232,9 @@ def _apply_residual(self, out, residual): return op.ccl_allreduce(out, "sum") + residual return out + residual - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): + def forward( + self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int + ): # pylint: disable=missing-function-docstring out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) @@ -262,7 +266,9 @@ def __init__(self, config: OLMoConfig): i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): + def forward( + self, inputs: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring hidden_states = inputs for layer_id, layer in enumerate(self.layers): if layer_id != 0 and layer_id in self.layer_partition: @@ -312,7 +318,7 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def batch_forward( + def batch_forward( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, @@ -326,7 +332,7 @@ def batch_forward( hidden_states = op.take(hidden_states, logit_positions, axis=1) return self.get_logits(hidden_states) - def batch_forward_to_last_hidden_states( + def batch_forward_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache, @@ -335,12 +341,12 @@ def batch_forward_to_last_hidden_states( hidden_states = self.model(input_embeds, paged_kv_cache) return hidden_states - def embed(self, input_ids: Tensor): + def embed(self, input_ids: Tensor): # pylint: disable=missing-function-docstring if self.tensor_parallel_shards > 1: input_ids = op.ccl_broadcast_from_worker0(input_ids) return self.model.embed_tokens(input_ids) - def get_logits(self, hidden_states: Tensor): + def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function-docstring op_ext.configure() if self.tie_word_embeddings: logits = self.model.embed_tokens.lm_head_forward(hidden_states) @@ -350,14 +356,18 @@ def get_logits(self, hidden_states: Tensor): logits = logits.astype("float32") return logits - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): + def batch_select_last_hidden_states( + self, hidden_states: Tensor, logit_positions: Tensor + ): # pylint: disable=missing-function-docstring op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states - def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def prefill( + self, input_embed: Tensor, 
paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring op_ext.configure() def _index(x: te.Tensor): # get tensor of the last sequence @@ -369,55 +379,65 @@ def _index(x: te.Tensor): # get tensor of the last sequence logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def decode( + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def prefill_to_last_hidden_states( + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): + def decode_to_last_hidden_states( + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_prefill( + def batch_prefill( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache ): logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache - def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + def batch_decode( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + def batch_verify( + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_prefill_to_last_hidden_states( + def batch_prefill_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_decode_to_last_hidden_states( + def batch_decode_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def batch_verify_to_last_hidden_states( + def batch_verify_to_last_hidden_states( # pylint: disable=missing-function-docstring self, input_embeds: Tensor, paged_kv_cache: PagedKVCache ): hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def create_paged_kv_cache( # pylint: disable=too-many-arguments,too-many-positional-arguments + def create_paged_kv_cache( # pylint: too-many-positional-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, @@ -443,7 +463,7 @@ def create_paged_kv_cache( # pylint: disable=too-many-arguments,too-many-positi dtype=self.dtype, ) - def get_default_spec(self): + def 
get_default_spec(self): # pylint: disable=missing-function-docstring mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), From 26d4093f7e5b56344979c31c7f803ac3b3613723 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 12:32:11 +0800 Subject: [PATCH 6/8] Fix pylint error --- python/mlc_llm/model/olmo/olmo_model.py | 46 +++++++------------------ 1 file changed, 12 insertions(+), 34 deletions(-) diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 322fb35ba3..2541bbc9c7 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -136,9 +136,7 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward( - self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int - ): # pylint: disable=missing-function-docstring + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -232,9 +230,7 @@ def _apply_residual(self, out, residual): return op.ccl_allreduce(out, "sum") + residual return out + residual - def forward( - self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int - ): # pylint: disable=missing-function-docstring + def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) @@ -266,9 +262,7 @@ def __init__(self, config: OLMoConfig): i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - def forward( - self, inputs: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring hidden_states = inputs for layer_id, layer in enumerate(self.layers): if layer_id != 0 and layer_id in self.layer_partition: @@ -278,9 +272,7 @@ def forward( return hidden_states -class OLMoForCausalLM( # pylint: disable=missing-class-docstring,too-many-instance-attributes - nn.Module -): +class OLMoForCausalLM(nn.Module): # pylint: disable=missing-class-docstring,too-many-instance-attributes def __init__(self, config: OLMoConfig): self.model = OLMoModel(config) self.tie_word_embeddings = config.tie_word_embeddings @@ -356,18 +348,14 @@ def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function logits = logits.astype("float32") return logits - def batch_select_last_hidden_states( - self, hidden_states: Tensor, logit_positions: Tensor - ): # pylint: disable=missing-function-docstring + def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): # pylint: disable=missing-function-docstring op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states - def prefill( - self, input_embed: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() def _index(x: 
te.Tensor): # get tensor of the last sequence @@ -379,24 +367,18 @@ def _index(x: te.Tensor): # get tensor of the last sequence logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def decode( - self, input_embed: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def prefill_to_last_hidden_states( - self, input_embed: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def decode_to_last_hidden_states( - self, input_embed: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache @@ -407,15 +389,11 @@ def batch_prefill( # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache - def batch_decode( - self, input_embeds: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_verify( - self, input_embeds: Tensor, paged_kv_cache: PagedKVCache - ): # pylint: disable=missing-function-docstring + def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache @@ -437,7 +415,7 @@ def batch_verify_to_last_hidden_states( # pylint: disable=missing-function-docs hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def create_paged_kv_cache( # pylint: too-many-positional-arguments + def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-many-arguments,too-many-positional-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, From a02a5e099fb9b4a60e16e7761b3cd52a7dcf3de3 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 12:42:34 +0800 Subject: [PATCH 7/8] Fix pylint and black error --- python/mlc_llm/model/olmo/olmo_model.py | 53 +++++++++++++++++++------ 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 2541bbc9c7..55b6936e8d 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -136,7 +136,9 @@ def __init__(self, config: OLMoConfig): bias=False, ) - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring + def forward( # pylint: disable=missing-function-docstring + self, hidden_states: Tensor, 
paged_kv_cache: PagedKVCache, layer_id: int + ): d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads b, s, _ = hidden_states.shape @@ -193,6 +195,9 @@ def forward(self, x: Tensor): # pylint: disable=missing-function-docstring return self.down_proj(self.act_fn(x1) * x2) +# pylint: disable=trailing-whitespace + + class OLMoDecoderLayer(nn.Module): # pylint: disable=missing-class-docstring def __init__(self, config: OLMoConfig): self.input_layernorm = nn.LayerNorm( @@ -230,7 +235,9 @@ def _apply_residual(self, out, residual): return op.ccl_allreduce(out, "sum") + residual return out + residual - def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int): # pylint: disable=missing-function-docstring + def forward( # pylint: disable=missing-function-docstring + self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int + ): out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id) hidden_states = self._apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) @@ -262,7 +269,9 @@ def __init__(self, config: OLMoConfig): i * layers_per_stage for i in range(config.pipeline_parallel_stages) ] + [config.num_hidden_layers] - def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def forward( # pylint: disable=missing-function-docstring + self, inputs: Tensor, paged_kv_cache: PagedKVCache + ): hidden_states = inputs for layer_id, layer in enumerate(self.layers): if layer_id != 0 and layer_id in self.layer_partition: @@ -272,7 +281,9 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache): # pylint: disa return hidden_states -class OLMoForCausalLM(nn.Module): # pylint: disable=missing-class-docstring,too-many-instance-attributes +class OLMoForCausalLM( # pylint: disable=missing-class-docstring,too-many-instance-attributes + nn.Module +): def __init__(self, config: OLMoConfig): self.model = OLMoModel(config) self.tie_word_embeddings = config.tie_word_embeddings @@ -348,37 +359,49 @@ def get_logits(self, hidden_states: Tensor): # pylint: disable=missing-function logits = logits.astype("float32") return logits - def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor): # pylint: disable=missing-function-docstring + def batch_select_last_hidden_states( # pylint: disable=missing-function-docstring + self, hidden_states: Tensor, logit_positions: Tensor + ): op_ext.configure() if self.tensor_parallel_shards > 1: logit_positions = op.ccl_broadcast_from_worker0(logit_positions) hidden_states = op.take(hidden_states, logit_positions, axis=0) return hidden_states - def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def prefill( # pylint: disable=missing-function-docstring + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): op_ext.configure() def _index(x: te.Tensor): # get tensor of the last sequence b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k]) + # pylint: disable=trailing-whitespace hidden_states = self.model(input_embed, paged_kv_cache) hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + # pylint: disable=trailing-whitespace + def decode( # pylint: 
disable=missing-function-docstring + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) logits = self.get_logits(hidden_states) return logits, paged_kv_cache - def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def prefill_to_last_hidden_states( # pylint: disable=missing-function-docstring + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache - def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def decode_to_last_hidden_states( # pylint: disable=missing-function-docstring + self, input_embed: Tensor, paged_kv_cache: PagedKVCache + ): op_ext.configure() hidden_states = self.model(input_embed, paged_kv_cache) return hidden_states, paged_kv_cache @@ -389,11 +412,15 @@ def batch_prefill( # pylint: disable=missing-function-docstring logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions) return logits, paged_kv_cache - def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def batch_decode( # pylint: disable=missing-function-docstring + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache - def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): # pylint: disable=missing-function-docstring + def batch_verify( # pylint: disable=missing-function-docstring + self, input_embeds: Tensor, paged_kv_cache: PagedKVCache + ): logits = self.batch_forward(input_embeds, paged_kv_cache) return logits, paged_kv_cache @@ -415,7 +442,7 @@ def batch_verify_to_last_hidden_states( # pylint: disable=missing-function-docs hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache) return hidden_states, paged_kv_cache - def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-many-arguments,too-many-positional-arguments + def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-many-arguments self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, @@ -441,7 +468,7 @@ def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-man dtype=self.dtype, ) - def get_default_spec(self): # pylint: disable=missing-function-docstring + def get_default_spec(self): mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"), From 046ad012d817a3267cf660d4ef50657083434047 Mon Sep 17 00:00:00 2001 From: Lanssi <962673761@qq.com> Date: Sun, 24 Nov 2024 12:46:31 +0800 Subject: [PATCH 8/8] Fix pylint error --- python/mlc_llm/model/olmo/olmo_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 55b6936e8d..06f00a0b6a 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -468,7 +468,7 @@ def create_paged_kv_cache( # pylint: disable=missing-function-docstring,too-man dtype=self.dtype, ) - def get_default_spec(self): + def get_default_spec(self): # pylint: disable=missing-function-docstring mod_spec = { "embed": { "input_ids": nn.spec.Tensor(["seq_len"], "int32"),