From 95d268bf1c072206a6ae4e51143fbfc263c0d7b6 Mon Sep 17 00:00:00 2001
From: Jeethu Rao
Date: Mon, 8 Apr 2024 20:36:59 +0100
Subject: [PATCH] [Model] Use tanh approximation of GeLU in Gemma MLP (#2106)

This is in line with the implementation in the [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L183) library, and with the [gemma-1.1](https://huggingface.co/google/gemma-1.1-2b-it/blob/main/config.json#L10) model config.
---
 python/mlc_llm/model/gemma/gemma_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py
index 5950ab2972..118f3ce856 100644
--- a/python/mlc_llm/model/gemma/gemma_model.py
+++ b/python/mlc_llm/model/gemma/gemma_model.py
@@ -39,7 +39,7 @@ class GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
-        if self.hidden_act != "gelu":
+        if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"):
             raise ValueError("Only GeLU is supported as the activation for gemma.")
         if self.attention_bias:
             raise ValueError('Only "False" attention_bias is supported for gemma')
@@ -115,7 +115,7 @@ def __init__(self, config: GemmaConfig):
     def forward(self, x: Tensor):
         concat_x1_x2 = self.gate_up_proj(x)
         x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
-        return self.down_proj(op.gelu(x1) * x2)
+        return self.down_proj(op.gelu(x1, approximate="tanh") * x2)
 
 
 class GemmaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
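
For context (illustrative only, not part of this patch): `gelu_pytorch_tanh` denotes the tanh approximation of GeLU, as opposed to the exact erf-based form that plain `gelu` selects. A minimal standalone sketch of the two formulas, using nothing beyond the Python standard library:

```python
import math


def gelu_exact(x: float) -> float:
    # Exact GeLU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x: float) -> float:
    # Tanh approximation ("gelu_pytorch_tanh"):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


if __name__ == "__main__":
    for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(f"x={v:+.1f}  exact={gelu_exact(v):+.6f}  tanh={gelu_tanh(v):+.6f}")
```

The two curves are numerically close but not identical, so selecting the tanh variant here keeps Gemma's MLP activations consistent with the transformers implementation and the gemma-1.1 config linked above.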