diff --git a/python/mlc_llm/model/olmo/olmo_model.py b/python/mlc_llm/model/olmo/olmo_model.py index 128a75baf5..5aef0603f4 100644 --- a/python/mlc_llm/model/olmo/olmo_model.py +++ b/python/mlc_llm/model/olmo/olmo_model.py @@ -105,11 +105,11 @@ def __init__(self, config: OlmoConfig): def forward( self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int - ):# pylint: disable=W0511 + ): # pylint: disable=W0511 d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads b, s, _ = hidden_states.shape qkv = self.qkv_proj(hidden_states) - # TODO: implement qkv clipping + # TODO: implement qkv clipping # if self.qkv_clip is not None: # between qkv_clip and -qkv_clip # qkv_clip = _tensor_op._convert_scalar(self.qkv_clip, ref=qkv) # qkv_clamped = op.where(qkv < op.negative(qkv_clip), op.negative(qkv_clip), qkv)