From 78189568e9285362f4cfc75ca3ebbb48983695bc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 28 Mar 2024 15:57:53 +0800 Subject: [PATCH 1/2] pt: make gelu jit happy --- deepmd/pt/utils/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 4b64a7231c..251da73ffa 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -23,7 +23,12 @@ def get_activation_fn(activation: str) -> Callable: if activation.lower() == "relu": return F.relu elif activation.lower() == "gelu" or activation.lower() == "gelu_tf": - return lambda x: F.gelu(x, approximate="tanh") + + @torch.jit.script + def gelu_tanh(x): + return F.gelu(x, approximate="tanh") + + return gelu_tanh elif activation.lower() == "tanh": return torch.tanh elif activation.lower() == "relu6": From ae475637513e3665c8f166f388fc69d5bbe0d93a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:39:26 +0800 Subject: [PATCH 2/2] use unified activation --- deepmd/pt/model/descriptor/repformer_layer.py | 4 +-- deepmd/pt/model/descriptor/repformers.py | 4 +-- deepmd/pt/model/network/network.py | 7 +++-- deepmd/pt/utils/utils.py | 26 ------------------- source/tests/consistent/test_activation.py | 6 ++--- 5 files changed, 9 insertions(+), 38 deletions(-) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 08fcb17b09..a58d6b0e2c 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -13,7 +13,7 @@ env, ) from deepmd.pt.utils.utils import ( - get_activation_fn, + ActivationFn, ) @@ -332,7 +332,7 @@ def __init__( self.set_davg_zero = set_davg_zero self.do_bn_mode = do_bn_mode self.bn_momentum = bn_momentum - self.act = get_activation_fn(activation_function) + self.act = ActivationFn(activation_function) self.update_g1_has_grrg = update_g1_has_grrg self.update_g1_has_drrd = update_g1_has_drrd self.update_g1_has_conv = update_g1_has_conv diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index a908d2e057..16a38052b1 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -29,7 +29,7 @@ PairExcludeMask, ) from deepmd.pt.utils.utils import ( - get_activation_fn, + ActivationFn, ) from deepmd.utils.env_mat_stat import ( StatItem, @@ -117,7 +117,7 @@ def __init__( self.set_davg_zero = set_davg_zero self.g1_dim = g1_dim self.g2_dim = g2_dim - self.act = get_activation_fn(activation_function) + self.act = ActivationFn(activation_function) self.direct_dist = direct_dist self.add_type_ebd_to_seq = add_type_ebd_to_seq # order matters, placed after the assignment of self.ntypes diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index 10d0364c9b..60d5251994 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -27,7 +27,6 @@ from deepmd.pt.utils.utils import ( ActivationFn, - get_activation_fn, ) @@ -470,7 +469,7 @@ class MaskLMHead(nn.Module): def __init__(self, embed_dim, output_dim, activation_fn, weight=None): super().__init__() self.dense = SimpleLinear(embed_dim, embed_dim) - self.activation_fn = get_activation_fn(activation_fn) + self.activation_fn = ActivationFn(activation_fn) self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION) if weight is None: @@ -818,7 +817,7 @@ def __init__( self.fc1 = nn.Linear( self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION ) - self.activation_fn = get_activation_fn(activation) + self.activation_fn = ActivationFn(activation) self.fc2 = nn.Linear( self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION ) @@ -1387,7 +1386,7 @@ def __init__( self.ffn_dim = ffn_dim self.attn_head = attn_head self.activation_fn = ( - get_activation_fn(activation_fn) if activation_fn is not None else None + ActivationFn(activation_fn) if activation_fn is not None else None ) self.post_ln = post_ln self.self_attn_layer_norm = nn.LayerNorm( diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index 251da73ffa..d1ef089e49 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( - Callable, Optional, overload, ) @@ -18,31 +17,6 @@ from .env import PRECISION_DICT as PT_PRECISION_DICT -def get_activation_fn(activation: str) -> Callable: - """Returns the activation function corresponding to `activation`.""" - if activation.lower() == "relu": - return F.relu - elif activation.lower() == "gelu" or activation.lower() == "gelu_tf": - - @torch.jit.script - def gelu_tanh(x): - return F.gelu(x, approximate="tanh") - - return gelu_tanh - elif activation.lower() == "tanh": - return torch.tanh - elif activation.lower() == "relu6": - return F.relu6 - elif activation.lower() == "softplus": - return F.softplus - elif activation.lower() == "sigmoid": - return torch.sigmoid - elif activation.lower() == "linear" or activation.lower() == "none": - return lambda x: x - else: - raise RuntimeError(f"activation function {activation} not supported") - - class ActivationFn(torch.nn.Module): def __init__(self, activation: Optional[str]): super().__init__() diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 83b8494729..9dcac6746e 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -15,7 +15,7 @@ ) if INSTALLED_PT: - from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt + from deepmd.pt.utils.utils import ActivationFn as ActivationFn_pt from deepmd.pt.utils.utils import ( to_numpy_array, to_torch_tensor, @@ -49,8 +49,6 @@ def test_tf_consistent_with_ref(self): def test_pt_consistent_with_ref(self): if INSTALLED_PT: test = to_numpy_array( - get_activation_fn_pt(self.activation)( - to_torch_tensor(self.random_input) - ) + ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input)) ) np.testing.assert_allclose(self.ref, test, atol=1e-10)