pt: use unified activation #3619

Merged 3 commits on Mar 28, 2024

Changes from all commits

deepmd/pt/model/descriptor/repformer_layer.py (4 changes: 2 additions & 2 deletions)
@@ -13,7 +13,7 @@
     env,
 )
 from deepmd.pt.utils.utils import (
-    get_activation_fn,
+    ActivationFn,
 )


@@ -332,7 +332,7 @@ def __init__(
         self.set_davg_zero = set_davg_zero
         self.do_bn_mode = do_bn_mode
         self.bn_momentum = bn_momentum
-        self.act = get_activation_fn(activation_function)
+        self.act = ActivationFn(activation_function)
         self.update_g1_has_grrg = update_g1_has_grrg
         self.update_g1_has_drrd = update_g1_has_drrd
         self.update_g1_has_conv = update_g1_has_conv

deepmd/pt/model/descriptor/repformers.py (4 changes: 2 additions & 2 deletions)
@@ -29,7 +29,7 @@
     PairExcludeMask,
 )
 from deepmd.pt.utils.utils import (
-    get_activation_fn,
+    ActivationFn,
 )
 from deepmd.utils.env_mat_stat import (
     StatItem,
@@ -117,7 +117,7 @@ def __init__(
         self.set_davg_zero = set_davg_zero
         self.g1_dim = g1_dim
         self.g2_dim = g2_dim
-        self.act = get_activation_fn(activation_function)
+        self.act = ActivationFn(activation_function)
         self.direct_dist = direct_dist
         self.add_type_ebd_to_seq = add_type_ebd_to_seq
         # order matters, placed after the assignment of self.ntypes

deepmd/pt/model/network/network.py (7 changes: 3 additions & 4 deletions)
@@ -27,7 +27,6 @@

 from deepmd.pt.utils.utils import (
     ActivationFn,
-    get_activation_fn,
 )

@@ -470,7 +469,7 @@
     def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
         super().__init__()
         self.dense = SimpleLinear(embed_dim, embed_dim)
-        self.activation_fn = get_activation_fn(activation_fn)
+        self.activation_fn = ActivationFn(activation_fn)
         self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION)

         if weight is None:

Codecov (codecov/patch) warning: added line #L472 was not covered by tests.
@@ -818,7 +817,7 @@
         self.fc1 = nn.Linear(
             self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
         )
-        self.activation_fn = get_activation_fn(activation)
+        self.activation_fn = ActivationFn(activation)
         self.fc2 = nn.Linear(
             self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
         )

Codecov (codecov/patch) warning: added line #L820 was not covered by tests.
@@ -1387,7 +1386,7 @@
         self.ffn_dim = ffn_dim
         self.attn_head = attn_head
         self.activation_fn = (
-            get_activation_fn(activation_fn) if activation_fn is not None else None
+            ActivationFn(activation_fn) if activation_fn is not None else None
         )
         self.post_ln = post_ln
         self.self_attn_layer_norm = nn.LayerNorm(

deepmd/pt/utils/utils.py (21 changes: 0 additions & 21 deletions)
wanghan-iapcm marked this conversation as resolved.
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
-    Callable,
     Optional,
     overload,
 )
@@ -18,26 +17,6 @@
 from .env import PRECISION_DICT as PT_PRECISION_DICT


-def get_activation_fn(activation: str) -> Callable:
-    """Returns the activation function corresponding to `activation`."""
-    if activation.lower() == "relu":
-        return F.relu
-    elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
-        return lambda x: F.gelu(x, approximate="tanh")
-    elif activation.lower() == "tanh":
-        return torch.tanh
-    elif activation.lower() == "relu6":
-        return F.relu6
-    elif activation.lower() == "softplus":
-        return F.softplus
-    elif activation.lower() == "sigmoid":
-        return torch.sigmoid
-    elif activation.lower() == "linear" or activation.lower() == "none":
-        return lambda x: x
-    else:
-        raise RuntimeError(f"activation function {activation} not supported")
-
-
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]):
         super().__init__()
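
The diff above only shows the constructor of the pre-existing ActivationFn module; its forward method is collapsed. For orientation, here is a minimal sketch of what such a unified module could look like, assuming its forward simply reuses the name-to-function dispatch of the removed get_activation_fn. This is an illustrative assumption, not the merged implementation.

from typing import Optional

import torch
import torch.nn.functional as F


class ActivationFn(torch.nn.Module):
    """Sketch of a unified activation module (assumed, not the merged code)."""

    def __init__(self, activation: Optional[str]):
        super().__init__()
        # Treat None the same as the identity ("linear") activation.
        self.activation = activation.lower() if activation is not None else "linear"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same dispatch table as the removed get_activation_fn.
        if self.activation == "relu":
            return F.relu(x)
        elif self.activation in ("gelu", "gelu_tf"):
            return F.gelu(x, approximate="tanh")
        elif self.activation == "tanh":
            return torch.tanh(x)
        elif self.activation == "relu6":
            return F.relu6(x)
        elif self.activation == "softplus":
            return F.softplus(x)
        elif self.activation == "sigmoid":
            return torch.sigmoid(x)
        elif self.activation in ("linear", "none"):
            return x
        else:
            raise RuntimeError(f"activation function {self.activation} not supported")

One plausible motivation for funneling every call site through a module instead of returned lambdas is TorchScript compatibility: a scripted model can contain nn.Module submodules, while captured lambdas cannot be scripted.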

source/tests/consistent/test_activation.py (6 changes: 2 additions & 4 deletions)
@@ -15,7 +15,7 @@
 )

 if INSTALLED_PT:
-    from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt
+    from deepmd.pt.utils.utils import ActivationFn as ActivationFn_pt
     from deepmd.pt.utils.utils import (
         to_numpy_array,
         to_torch_tensor,
@@ -49,8 +49,6 @@ def test_tf_consistent_with_ref(self):
     def test_pt_consistent_with_ref(self):
         if INSTALLED_PT:
             test = to_numpy_array(
-                get_activation_fn_pt(self.activation)(
-                    to_torch_tensor(self.random_input)
-                )
+                ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input))
             )
             np.testing.assert_allclose(self.ref, test, atol=1e-10)
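
The updated test shows the call pattern used at every converted site: construct ActivationFn once with the activation name, then apply it like a function. A hypothetical standalone usage, assuming a DeePMD-kit PyTorch installation and that "gelu" remains a supported name:

import torch

from deepmd.pt.utils.utils import ActivationFn

x = torch.randn(4, 8)
act = ActivationFn("gelu")  # replaces the old get_activation_fn("gelu")
y = act(x)                  # module call; the consistency test checks this against a NumPy reference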