diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index d0a167f1dc..ade4d973ce 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -85,7 +85,10 @@ def __init__( # functype if activation_fn == ACTIVATION_FN_DICT["tanh"]: self.functype = 1 - elif activation_fn == ACTIVATION_FN_DICT["gelu"]: + elif activation_fn in ( + ACTIVATION_FN_DICT["gelu"], + ACTIVATION_FN_DICT["gelu_tf"], + ): self.functype = 2 elif activation_fn == ACTIVATION_FN_DICT["relu"]: self.functype = 3