diff --git a/deepmd/common.py b/deepmd/common.py
index 29d32111a8..c776975591 100644
--- a/deepmd/common.py
+++ b/deepmd/common.py
@@ -52,7 +52,15 @@
 _DICT_VAL = TypeVar("_DICT_VAL")
 _PRECISION = Literal["default", "float16", "float32", "float64"]
 _ACTIVATION = Literal[
-    "relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf"
+    "relu",
+    "relu6",
+    "softplus",
+    "sigmoid",
+    "tanh",
+    "gelu",
+    "gelu_tf",
+    "none",
+    "linear",
 ]
 __all__.extend(
     [
diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py
index feb3355e77..6206367b1b 100644
--- a/deepmd/dpmodel/utils/network.py
+++ b/deepmd/dpmodel/utils/network.py
@@ -10,6 +10,7 @@
     datetime,
 )
 from typing import (
+    Callable,
     ClassVar,
     Dict,
     List,
@@ -309,14 +310,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
         """
         if self.w is None or self.activation_function is None:
             raise ValueError("w, b, and activation_function must be set")
-        if self.activation_function == "tanh":
-            fn = np.tanh
-        elif self.activation_function.lower() == "none":
-
-            def fn(x):
-                return x
-        else:
-            raise NotImplementedError(self.activation_function)
+        fn = get_activation_fn(self.activation_function)
         y = (
             np.matmul(x, self.w) + self.b
             if self.b is not None
@@ -332,6 +326,55 @@ def fn(x):
         return y
 
 
+def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
+    activation_function = activation_function.lower()
+    if activation_function == "tanh":
+        return np.tanh
+    elif activation_function == "relu":
+
+        def fn(x):
+            # https://stackoverflow.com/a/47936476/9567349
+            return x * (x > 0)
+
+        return fn
+    elif activation_function in ("gelu", "gelu_tf"):
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+
+        return fn
+    elif activation_function == "relu6":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return np.minimum(np.maximum(x, 0), 6)
+
+        return fn
+    elif activation_function == "softplus":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return np.log(1 + np.exp(x))
+
+        return fn
+    elif activation_function == "sigmoid":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return 1 / (1 + np.exp(-x))
+
+        return fn
+    elif activation_function.lower() in ("none", "linear"):
+
+        def fn(x):
+            return x
+
+        return fn
+    else:
+        raise NotImplementedError(activation_function)
+
+
 def make_multilayer_network(T_NetworkLayer, ModuleBase):
     class NN(ModuleBase):
         """Native representation of a neural network.
diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py
index f5a4cd84b6..10dcadadac 100644
--- a/deepmd/pt/utils/utils.py
+++ b/deepmd/pt/utils/utils.py
@@ -21,10 +21,16 @@ def get_activation_fn(activation: str) -> Callable:
     """Returns the activation function corresponding to `activation`."""
     if activation.lower() == "relu":
         return F.relu
-    elif activation.lower() == "gelu":
-        return F.gelu
+    elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
+        return lambda x: F.gelu(x, approximate="tanh")
     elif activation.lower() == "tanh":
         return torch.tanh
+    elif activation.lower() == "relu6":
+        return F.relu6
+    elif activation.lower() == "softplus":
+        return F.softplus
+    elif activation.lower() == "sigmoid":
+        return torch.sigmoid
     elif activation.lower() == "linear" or activation.lower() == "none":
         return lambda x: x
     else:
@@ -42,10 +48,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         if self.activation.lower() == "relu":
             return F.relu(x)
-        elif self.activation.lower() == "gelu":
-            return F.gelu(x)
+        elif self.activation.lower() == "gelu" or self.activation.lower() == "gelu_tf":
+            return F.gelu(x, approximate="tanh")
         elif self.activation.lower() == "tanh":
             return torch.tanh(x)
+        elif self.activation.lower() == "relu6":
+            return F.relu6(x)
+        elif self.activation.lower() == "softplus":
+            return F.softplus(x)
+        elif self.activation.lower() == "sigmoid":
+            return torch.sigmoid(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py
index b1872e72ed..0d59990a29 100644
--- a/deepmd/tf/common.py
+++ b/deepmd/tf/common.py
@@ -135,14 +135,14 @@ def gelu_wrapper(x):
     "tanh": tf.nn.tanh,
     "gelu": gelu,
     "gelu_tf": gelu_tf,
-    "None": None,
-    "none": None,
+    "linear": lambda x: x,
+    "none": lambda x: x,
 }
 
 
 def get_activation_func(
     activation_fn: Union["_ACTIVATION", None],
-) -> Union[Callable[[tf.Tensor], tf.Tensor], None]:
+) -> Callable[[tf.Tensor], tf.Tensor]:
     """Get activation function callable based on string name.
 
     Parameters
@@ -161,10 +161,11 @@ def get_activation_func(
         if unknown activation function is specified
     """
     if activation_fn is None:
-        return None
-    if activation_fn not in ACTIVATION_FN_DICT:
+        activation_fn = "none"
+    assert activation_fn is not None
+    if activation_fn.lower() not in ACTIVATION_FN_DICT:
         raise RuntimeError(f"{activation_fn} is not a valid activation function")
-    return ACTIVATION_FN_DICT[activation_fn]
+    return ACTIVATION_FN_DICT[activation_fn.lower()]
 
 
 def get_precision(precision: "_PRECISION") -> Any:
diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py
new file mode 100644
index 0000000000..bb06df9082
--- /dev/null
+++ b/source/tests/consistent/test_activation.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+
+import numpy as np
+
+from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp
+
+from .common import (
+    INSTALLED_PT,
+    INSTALLED_TF,
+    parameterized,
+)
+
+if INSTALLED_PT:
+    from deepmd.pt.utils.utils import get_activation_fn as get_activation_fn_pt
+    from deepmd.pt.utils.utils import (
+        to_numpy_array,
+        to_torch_tensor,
+    )
+if INSTALLED_TF:
+    from deepmd.tf.common import get_activation_func as get_activation_fn_tf
+    from deepmd.tf.env import (
+        tf,
+    )
+
+
+@parameterized(
+    (
+        "Relu",
+        "Relu6",
+        "Softplus",
+        "Sigmoid",
+        "Tanh",
+        "Gelu",
+        "Gelu_tf",
+        "Linear",
+        "None",
+    ),
+)
+class TestActivationFunctionConsistent(unittest.TestCase):
+    def setUp(self):
+        (self.activation,) = self.param
+        self.random_input = np.random.default_rng().normal(scale=10, size=(10, 10))
+        self.ref = get_activation_fn_dp(self.activation)(self.random_input)
+
+    @unittest.skipUnless(INSTALLED_TF, "TensorFlow is not installed")
+    def test_tf_consistent_with_ref(self):
+        if INSTALLED_TF:
+            place_holder = tf.placeholder(tf.float64, self.random_input.shape)
+            t_test = get_activation_fn_tf(self.activation)(place_holder)
+            with tf.Session() as sess:
+                test = sess.run(t_test, feed_dict={place_holder: self.random_input})
+            np.testing.assert_allclose(self.ref, test, atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
+    def test_pt_consistent_with_ref(self):
+        if INSTALLED_PT:
+            test = to_numpy_array(
+                get_activation_fn_pt(self.activation)(
+                    to_torch_tensor(self.random_input)
+                )
+            )
+            np.testing.assert_allclose(self.ref, test, atol=1e-10)
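
Below is a minimal usage sketch, not part of the patch, of the NumPy reference implementation added above in deepmd/dpmodel/utils/network.py; the function name and its case-insensitive lookup come straight from the diff, while the sample input values are illustrative only:

    import numpy as np
    from deepmd.dpmodel.utils.network import get_activation_fn

    x = np.linspace(-3.0, 3.0, 7)
    # mixed-case "Gelu" works because the name is lowercased internally
    for name in ("relu", "Gelu", "softplus", "none"):
        fn = get_activation_fn(name)
        print(name, fn(x))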