pt: use unified activation (#3619)
iProzd authored Mar 28, 2024
1 parent 7933c5e commit 23f67a1
Showing 5 changed files with 9 additions and 33 deletions.
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
@@ -13,7 +13,7 @@
     env,
 )
 from deepmd.pt.utils.utils import (
-    get_activation_fn,
+    ActivationFn,
 )


@@ -332,7 +332,7 @@ def __init__(
         self.set_davg_zero = set_davg_zero
         self.do_bn_mode = do_bn_mode
         self.bn_momentum = bn_momentum
-        self.act = get_activation_fn(activation_function)
+        self.act = ActivationFn(activation_function)
         self.update_g1_has_grrg = update_g1_has_grrg
         self.update_g1_has_drrd = update_g1_has_drrd
         self.update_g1_has_conv = update_g1_has_conv
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
@@ -29,7 +29,7 @@
     PairExcludeMask,
 )
 from deepmd.pt.utils.utils import (
-    get_activation_fn,
+    ActivationFn,
 )
 from deepmd.utils.env_mat_stat import (
     StatItem,
@@ -117,7 +117,7 @@ def __init__(
         self.set_davg_zero = set_davg_zero
         self.g1_dim = g1_dim
         self.g2_dim = g2_dim
-        self.act = get_activation_fn(activation_function)
+        self.act = ActivationFn(activation_function)
         self.direct_dist = direct_dist
         self.add_type_ebd_to_seq = add_type_ebd_to_seq
         # order matters, placed after the assignment of self.ntypes
7 changes: 3 additions & 4 deletions deepmd/pt/model/network/network.py
@@ -27,7 +27,6 @@

 from deepmd.pt.utils.utils import (
     ActivationFn,
-    get_activation_fn,
 )


@@ -470,7 +469,7 @@ class MaskLMHead(nn.Module):
     def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
         super().__init__()
         self.dense = SimpleLinear(embed_dim, embed_dim)
-        self.activation_fn = get_activation_fn(activation_fn)
+        self.activation_fn = ActivationFn(activation_fn)
         self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION)

         if weight is None:
@@ -818,7 +817,7 @@ def __init__(
         self.fc1 = nn.Linear(
             self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
         )
-        self.activation_fn = get_activation_fn(activation)
+        self.activation_fn = ActivationFn(activation)
         self.fc2 = nn.Linear(
             self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
         )
@@ -1387,7 +1386,7 @@ def __init__(
         self.ffn_dim = ffn_dim
         self.attn_head = attn_head
         self.activation_fn = (
-            get_activation_fn(activation_fn) if activation_fn is not None else None
+            ActivationFn(activation_fn) if activation_fn is not None else None
         )
         self.post_ln = post_ln
         self.self_attn_layer_norm = nn.LayerNorm(
21 changes: 0 additions & 21 deletions deepmd/pt/utils/utils.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
-    Callable,
     Optional,
     overload,
 )
@@ -18,26 +17,6 @@
 from .env import PRECISION_DICT as PT_PRECISION_DICT


-def get_activation_fn(activation: str) -> Callable:
-    """Returns the activation function corresponding to `activation`."""
-    if activation.lower() == "relu":
-        return F.relu
-    elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
-        return lambda x: F.gelu(x, approximate="tanh")
-    elif activation.lower() == "tanh":
-        return torch.tanh
-    elif activation.lower() == "relu6":
-        return F.relu6
-    elif activation.lower() == "softplus":
-        return F.softplus
-    elif activation.lower() == "sigmoid":
-        return torch.sigmoid
-    elif activation.lower() == "linear" or activation.lower() == "none":
-        return lambda x: x
-    else:
-        raise RuntimeError(f"activation function {activation} not supported")
-
-
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]):
         super().__init__()
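The replacement class is only partially visible in this hunk. For reference, here is a minimal sketch of a unified ActivationFn module, reconstructed from the branches of the deleted get_activation_fn helper above; treat the forward body as an assumption rather than a copy of the committed source:

from typing import Optional

import torch
import torch.nn.functional as F


class ActivationFn(torch.nn.Module):
    """Activation selected by name; a sketch reconstructed from the
    deleted get_activation_fn helper, not the committed class body."""

    def __init__(self, activation: Optional[str]):
        super().__init__()
        # Normalize the name once at construction; None falls back to identity.
        self.activation = activation.lower() if activation is not None else "linear"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same dispatch as the removed helper, evaluated on each forward call.
        if self.activation == "relu":
            return F.relu(x)
        elif self.activation in ("gelu", "gelu_tf"):
            return F.gelu(x, approximate="tanh")
        elif self.activation == "tanh":
            return torch.tanh(x)
        elif self.activation == "relu6":
            return F.relu6(x)
        elif self.activation == "softplus":
            return F.softplus(x)
        elif self.activation == "sigmoid":
            return torch.sigmoid(x)
        elif self.activation in ("linear", "none"):
            return x
        raise RuntimeError(f"activation function {self.activation} not supported")

Wrapping the dispatch in an nn.Module rather than returning bare functions and lambdas means the activation registers as a proper submodule, which also plays better with TorchScript; that is a plausible motivation for the change, though the commit itself does not state one.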
6 changes: 2 additions & 4 deletions source/tests/consistent/test_activation.py
@@ -15,7 +15,7 @@
 )

 if INSTALLED_PT:
-    from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt
+    from deepmd.pt.utils.utils import ActivationFn as ActivationFn_pt
     from deepmd.pt.utils.utils import (
         to_numpy_array,
         to_torch_tensor,
@@ -49,8 +49,6 @@ def test_tf_consistent_with_ref(self):
     def test_pt_consistent_with_ref(self):
         if INSTALLED_PT:
             test = to_numpy_array(
-                get_activation_fn_pt(self.activation)(
-                    to_torch_tensor(self.random_input)
-                )
+                ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
             )
             np.testing.assert_allclose(self.ref, test, atol=1e-10)
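
The call-site change seen throughout the diff is mechanical: the old helper returned a function, while the new module is itself callable. A minimal usage sketch (the input tensor is illustrative):

import torch

from deepmd.pt.utils.utils import ActivationFn

act = ActivationFn("gelu")  # previously: fn = get_activation_fn("gelu")
x = torch.randn(4, 8)       # illustrative input
y = act(x)                  # nn.Module __call__, equivalent to fn(x)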
