
Merge branch 'devel' into dlkf
CaRoLZhangxy authored Mar 11, 2024
2 parents 4b3a7c6 + a286bd4 commit 18e32b8
Showing 23 changed files with 221 additions and 463 deletions.
10 changes: 9 additions & 1 deletion deepmd/common.py
@@ -52,7 +52,15 @@
 _DICT_VAL = TypeVar("_DICT_VAL")
 _PRECISION = Literal["default", "float16", "float32", "float64"]
 _ACTIVATION = Literal[
-    "relu", "relu6", "softplus", "sigmoid", "tanh", "gelu", "gelu_tf"
+    "relu",
+    "relu6",
+    "softplus",
+    "sigmoid",
+    "tanh",
+    "gelu",
+    "gelu_tf",
+    "none",
+    "linear",
 ]
 __all__.extend(
     [
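With "none" and "linear" added to _ACTIVATION, configurations that request an identity activation now pass the Literal check. A minimal sketch (not part of the commit; check_activation is a hypothetical helper) of how the extended Literal can back a runtime check:

from typing import Literal, get_args

_ACTIVATION = Literal[
    "relu", "relu6", "softplus", "sigmoid", "tanh",
    "gelu", "gelu_tf", "none", "linear",
]

def check_activation(name: str) -> str:
    """Return the lower-cased name if it names a supported activation."""
    name = name.lower()
    if name not in get_args(_ACTIVATION):
        raise ValueError(f"unsupported activation function: {name}")
    return name

check_activation("GELU")    # -> "gelu"
check_activation("linear")  # -> "linear"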
59 changes: 51 additions & 8 deletions deepmd/dpmodel/utils/network.py
@@ -10,6 +10,7 @@
     datetime,
 )
 from typing import (
+    Callable,
     ClassVar,
     Dict,
     List,
@@ -309,14 +310,7 @@ def call(self, x: np.ndarray) -> np.ndarray:
         """
         if self.w is None or self.activation_function is None:
             raise ValueError("w, b, and activation_function must be set")
-        if self.activation_function == "tanh":
-            fn = np.tanh
-        elif self.activation_function.lower() == "none":
-
-            def fn(x):
-                return x
-        else:
-            raise NotImplementedError(self.activation_function)
+        fn = get_activation_fn(self.activation_function)
         y = (
             np.matmul(x, self.w) + self.b
             if self.b is not None
@@ -332,6 +326,55 @@ def fn(x):
         return y
 
 
+def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
+    activation_function = activation_function.lower()
+    if activation_function == "tanh":
+        return np.tanh
+    elif activation_function == "relu":
+
+        def fn(x):
+            # https://stackoverflow.com/a/47936476/9567349
+            return x * (x > 0)
+
+        return fn
+    elif activation_function in ("gelu", "gelu_tf"):
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+
+        return fn
+    elif activation_function == "relu6":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return np.minimum(np.maximum(x, 0), 6)
+
+        return fn
+    elif activation_function == "softplus":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return np.log(1 + np.exp(x))
+
+        return fn
+    elif activation_function == "sigmoid":
+
+        def fn(x):
+            # generated by GitHub Copilot
+            return 1 / (1 + np.exp(-x))
+
+        return fn
+    elif activation_function.lower() in ("none", "linear"):
+
+        def fn(x):
+            return x
+
+        return fn
+    else:
+        raise NotImplementedError(activation_function)
+
+
 def make_multilayer_network(T_NetworkLayer, ModuleBase):
     class NN(ModuleBase):
         """Native representation of a neural network.
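For orientation, a small usage sketch of the new helper on a toy array (not part of the commit; the import path is assumed from the file patched above):

import numpy as np

from deepmd.dpmodel.utils.network import get_activation_fn  # path assumed from this diff

x = np.array([-2.0, 0.0, 3.0])
for name in ("relu", "relu6", "sigmoid", "tanh", "gelu", "none"):
    fn = get_activation_fn(name)  # returns a plain NumPy callable
    print(name, fn(x))
# "none" and "linear" map to the identity, so fn(x) == x for those names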
3 changes: 2 additions & 1 deletion deepmd/main.py
@@ -798,7 +798,8 @@ def main():
     ):
         deepmd_main = BACKENDS[args.backend]().entry_point_hook
     elif args.command is None:
-        pass
+        # help message has been printed in parse_args
+        return
     else:
         raise RuntimeError(f"unknown command {args.command}")

10 changes: 4 additions & 6 deletions deepmd/pt/entrypoints/main.py
@@ -134,7 +134,7 @@ def prepare_trainer_input_single(
             DpLoaderSet(
                 validation_systems,
                 validation_dataset_params["batch_size"],
-                model_params_single,
+                model_params_single["type_map"],
             )
             if validation_systems
             else None
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
-                model_params_single,
+                model_params_single["type_map"],
             )
         else:
             train_data_single = DpLoaderSet(
                 training_systems,
                 training_dataset_params["batch_size"],
-                model_params_single,
+                model_params_single["type_map"],
             )
         return (
             train_data_single,
@@ -281,9 +281,7 @@ def train(FLAGS):


 def freeze(FLAGS):
-    model = torch.jit.script(
-        inference.Tester(FLAGS.model, numb_test=1, head=FLAGS.head).model
-    )
+    model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
     torch.jit.save(
         model,
         FLAGS.output,
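The freeze() change drops the numb_test=1 argument from the Tester call but keeps the same TorchScript flow. A generic sketch of that script-then-save pattern (not deepmd-specific; TinyNet is a stand-in module):

import torch


class TinyNet(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


scripted = torch.jit.script(TinyNet())   # compile the module to TorchScript
torch.jit.save(scripted, "tiny_net.pt")  # serialize it, as freeze() does with FLAGS.output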
(Diffs for the remaining 19 changed files are not shown.)
