[Model] Use tanh approximation of GeLU in Gemma MLP (#2106)
jeethu authored Apr 8, 2024
1 parent cc8b747 commit 95d268b
Showing 1 changed file with 2 additions and 2 deletions.
python/mlc_llm/model/gemma/gemma_model.py (2 additions, 2 deletions)
@@ -39,7 +39,7 @@ class GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
-        if self.hidden_act != "gelu":
+        if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"):
             raise ValueError("Only GeLU is supported as the activation for gemma.")
         if self.attention_bias:
             raise ValueError('Only "False" attention_bias is supported for gemma')
@@ -115,7 +115,7 @@ def __init__(self, config: GemmaConfig):
     def forward(self, x: Tensor):
        concat_x1_x2 = self.gate_up_proj(x)
        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)
-        return self.down_proj(op.gelu(x1) * x2)
+        return self.down_proj(op.gelu(x1, approximate="tanh") * x2)
 
 
 class GemmaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes
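Context for the change: "gelu_pytorch_tanh" is the Hugging Face / PyTorch name for the tanh-approximated GeLU (nn.GELU(approximate="tanh")), which is what Gemma configs specify, so the MLP now selects the same variant via op.gelu(x1, approximate="tanh"). The sketch below is a plain NumPy illustration, not MLC's implementation; the function names gelu_exact and gelu_tanh are made up for this example.

# Sketch (not MLC code): exact erf-based GeLU vs. the tanh approximation.
import math

import numpy as np


def gelu_exact(x: np.ndarray) -> np.ndarray:
    # Exact GeLU: x * Phi(x), with Phi the standard normal CDF (via erf).
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # Tanh approximation, matching PyTorch's GELU(approximate="tanh"),
    # i.e. the "gelu_pytorch_tanh" activation name accepted above.
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))


x = np.linspace(-4.0, 4.0, 17)
print(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))  # small, roughly a few 1e-4

The two forms differ only slightly, but matching the variant the model was trained with avoids a small, systematic numerical mismatch at inference time, which is presumably the motivation for this commit.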
