Consistent activation functions between backends (#3431)
1. add relu, gelu, gelu_tf, relu6, softplus, sigmoid, and linear to dpmodel;
2. add gelu_tf, relu6, softplus, and sigmoid to pt;
3. change gelu in pt from non-approximate to approximate (see the sketch below). If one still wants to use the non-approximate version, we may consider adding a new key;
4. add linear to tf;
5. none in tf now returns `lambda x: x` instead of `None` to be type-consistent;
6. support uppercase in all backends;
7. add consistent tests.
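
For item 3, the practical difference is between PyTorch's erf-based GELU and its tanh approximation. A minimal sketch of the change (not part of the diff, assuming a PyTorch version where `torch.nn.functional.gelu` accepts `approximate="tanh"`):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, 7, dtype=torch.float64)

exact = F.gelu(x)                       # erf-based GELU (previous behavior of "gelu" in pt)
approx = F.gelu(x, approximate="tanh")  # tanh approximation (new behavior, matches gelu_tf)

# The two curves agree closely but not exactly, which is why a separate
# key for the exact variant may be added later.
print((exact - approx).abs().max())
```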

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Mar 8, 2024

1 parent fefc0e6 commit dabbd35
Showing 5 changed files with 146 additions and 19 deletions.
10 changes: 9 additions & 1 deletion deepmd/common.py
@@ -52,7 +52,15 @@
 _DICT_VAL = TypeVar("_DICT_VAL")
 _PRECISION = Literal["default", "float16", "float32", "float64"]
 _ACTIVATION = Literal[
-    "relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf"
+    "relu",
+    "relu6",
+    "softplus",
+    "sigmoid",
+    "tanh",
+    "gelu",
+    "gelu_tf",
+    "none",
+    "linear",
 ]
 __all__.extend(
     [
59 changes: 51 additions & 8 deletions deepmd/dpmodel/utils/network.py
@@ -10,6 +10,7 @@
datetime,
)
from typing import (
Callable,
ClassVar,
Dict,
List,
@@ -309,14 +310,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
"""
if self.w is None or self.activation_function is None:
raise ValueError("w, b, and activation_function must be set")
if self.activation_function == "tanh":
fn = np.tanh
elif self.activation_function.lower() == "none":

def fn(x):
return x
else:
raise NotImplementedError(self.activation_function)
fn = get_activation_fn(self.activation_function)
y = (
np.matmul(x, self.w) + self.b
if self.b is not None
@@ -332,6 +326,55 @@ def fn(x):
return y


def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
activation_function = activation_function.lower()
if activation_function == "tanh":
return np.tanh
elif activation_function == "relu":

def fn(x):
# https://stackoverflow.com/a/47936476/9567349
return x * (x > 0)

return fn
elif activation_function in ("gelu", "gelu_tf"):

def fn(x):
# generated by GitHub Copilot
return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

return fn
elif activation_function == "relu6":

def fn(x):
# generated by GitHub Copilot
return np.minimum(np.maximum(x, 0), 6)

return fn
elif activation_function == "softplus":

def fn(x):
# generated by GitHub Copilot
return np.log(1 + np.exp(x))

return fn
elif activation_function == "sigmoid":

def fn(x):
# generated by GitHub Copilot
return 1 / (1 + np.exp(-x))

return fn
elif activation_function.lower() in ("none", "linear"):

def fn(x):
return x

return fn
else:
raise NotImplementedError(activation_function)


def make_multilayer_network(T_NetworkLayer, ModuleBase):
class NN(ModuleBase):
"""Native representation of a neural network.
20 changes: 16 additions & 4 deletions deepmd/pt/utils/utils.py
@@ -21,10 +21,16 @@ def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`."""
if activation.lower() == "relu":
return F.relu
elif activation.lower() == "gelu":
return F.gelu
elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
return lambda x: F.gelu(x, approximate="tanh")
elif activation.lower() == "tanh":
return torch.tanh
elif activation.lower() == "relu6":
return F.relu6
elif activation.lower() == "softplus":
return F.softplus
elif activation.lower() == "sigmoid":
return torch.sigmoid
elif activation.lower() == "linear" or activation.lower() == "none":
return lambda x: x
else:
@@ -42,10 +48,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         if self.activation.lower() == "relu":
             return F.relu(x)
-        elif self.activation.lower() == "gelu":
-            return F.gelu(x)
+        elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
+            return F.gelu(x, approximate="tanh")
         elif self.activation.lower() == "tanh":
             return torch.tanh(x)
+        elif self.activation.lower() == "relu6":
+            return F.relu6(x)
+        elif self.activation.lower() == "softplus":
+            return F.softplus(x)
+        elif self.activation.lower() == "sigmoid":
+            return torch.sigmoid(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
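
A minimal sketch of the updated PyTorch helper (not part of the commit, assuming the pt backend is installed):

```python
import torch
import torch.nn.functional as F

from deepmd.pt.utils.utils import get_activation_fn

x = torch.randn(4, 4, dtype=torch.float64)

# "gelu" and "gelu_tf" now both resolve to the tanh-approximate GELU.
torch.testing.assert_close(
    get_activation_fn("Gelu")(x), F.gelu(x, approximate="tanh")
)

# "linear" and "none" return identity callables rather than None.
assert torch.equal(get_activation_fn("none")(x), x)
```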
13 changes: 7 additions & 6 deletions deepmd/tf/common.py
@@ -135,14 +135,14 @@ def gelu_wrapper(x):
"tanh": tf.nn.tanh,
"gelu": gelu,
"gelu_tf": gelu_tf,
"None": None,
"none": None,
"linear": lambda x: x,
"none": lambda x: x,
}


def get_activation_func(
activation_fn: Union["_ACTIVATION", None],
) -> Union[Callable[[tf.Tensor], tf.Tensor], None]:
) -> Callable[[tf.Tensor], tf.Tensor]:
"""Get activation function callable based on string name.
Parameters
@@ -161,10 +161,11 @@ def get_activation_func(
         if unknown activation function is specified
     """
     if activation_fn is None:
-        return None
-    if activation_fn not in ACTIVATION_FN_DICT:
+        activation_fn = "none"
+    assert activation_fn is not None
+    if activation_fn.lower() not in ACTIVATION_FN_DICT:
         raise RuntimeError(f"{activation_fn} is not a valid activation function")
-    return ACTIVATION_FN_DICT[activation_fn]
+    return ACTIVATION_FN_DICT[activation_fn.lower()]
 
 
 def get_precision(precision: "_PRECISION") -> Any:
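
A minimal sketch of the TensorFlow-side change (not part of the commit): `get_activation_func` now always returns a callable, and lookups are case-insensitive:

```python
from deepmd.tf.common import get_activation_func

# Passing None (or "none"/"linear") yields an identity callable instead of None.
identity = get_activation_func(None)
assert identity(1.5) == 1.5

# Lookup is case-insensitive, so "Tanh" and "tanh" resolve to the same function.
assert get_activation_func("Tanh") is get_activation_func("tanh")
```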
63 changes: 63 additions & 0 deletions source/tests/consistent/test_activation.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp

from .common import (
INSTALLED_PT,
INSTALLED_TF,
parameterized,
)

if INSTALLED_PT:
from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt
from deepmd.pt.utils.utils import (
to_numpy_array,
to_torch_tensor,
)
if INSTALLED_TF:
from deepmd.tf.common import get_activation_func as get_activation_fn_tf
from deepmd.tf.env import (
tf,
)


@parameterized(
(
"Relu",
"Relu6",
"Softplus",
"Sigmoid",
"Tanh",
"Gelu",
"Gelu_tf",
"Linear",
"None",
),
)
class TestActivationFunctionConsistent(unittest.TestCase):
def setUp(self):
(self.activation,) = self.param
self.random_input = np.random.default_rng().normal(scale=10, size=(10, 10))
self.ref = get_activation_fn_dp(self.activation)(self.random_input)

@unittest.skipUnless(INSTALLED_TF, "TensorFlow is not installed")
def test_tf_consistent_with_ref(self):
if INSTALLED_TF:
place_holder = tf.placeholder(tf.float64, self.random_input.shape)
t_test = get_activation_fn_tf(self.activation)(place_holder)
with tf.Session() as sess:
test = sess.run(t_test, feed_dict={place_holder: self.random_input})
np.testing.assert_allclose(self.ref, test, atol=1e-10)

@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
def test_pt_consistent_with_ref(self):
if INSTALLED_PT:
test = to_numpy_array(
get_activation_fn_pt(self.activation)(
to_torch_tensor(self.random_input)
)
)
np.testing.assert_allclose(self.ref, test, atol=1e-10)
