Skip to content

Commit

Permalink
Merge branch 'devel' into fix/zbl_freeze
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml authored Mar 8, 2024
2 parents 1c0a8a3 + 09bd522 commit ba493ab
Show file tree
Hide file tree
Showing 107 changed files with 3,979 additions and 390 deletions.
4 changes: 2 additions & 2 deletions deepmd/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

@Backend.register("pt")
@Backend.register("pytorch")
class TensorFlowBackend(Backend):
"""TensorFlow backend."""
class PyTorchBackend(Backend):
"""PyTorch backend."""

name = "PyTorch"
"""The formal name of the backend."""
Expand Down
8 changes: 6 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class DescrptSeA(NativeOP, BaseDescriptor):
exclude_types : List[List[int]]
The excluded pairs of types which have no interaction with each other.
For example, `[[0, 1]]` means no interaction between type 0 and type 1.
env_protection: float
Protection parameter to prevent division by zero errors during environment matrix calculations.
set_davg_zero
Set the shift of embedding net input to zero.
activation_function
Expand Down Expand Up @@ -149,6 +151,7 @@ def __init__(
trainable: bool = True,
type_one_side: bool = True,
exclude_types: List[List[int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
Expand All @@ -169,6 +172,7 @@ def __init__(
self.resnet_dt = resnet_dt
self.trainable = trainable
self.type_one_side = type_one_side
self.env_protection = env_protection
self.set_davg_zero = set_davg_zero
self.activation_function = activation_function
self.precision = precision
Expand All @@ -192,7 +196,7 @@ def __init__(
self.resnet_dt,
self.precision,
)
self.env_mat = EnvMat(self.rcut, self.rcut_smth)
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.davg = np.zeros(
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
Expand Down Expand Up @@ -378,6 +382,7 @@ def serialize(self) -> dict:
"trainable": self.trainable,
"type_one_side": self.type_one_side,
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function,
# make deterministic
Expand Down Expand Up @@ -406,7 +411,6 @@ def deserialize(cls, data: dict) -> "DescrptSeA":
obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
Expand Down
6 changes: 4 additions & 2 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
trainable: bool = True,
type_one_side: bool = True,
exclude_types: List[List[int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
self.precision = precision
self.spin = spin
self.emask = PairExcludeMask(self.ntypes, self.exclude_types)
self.env_protection = env_protection

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
Expand All @@ -150,7 +152,7 @@ def __init__(
self.resnet_dt,
self.precision,
)
self.env_mat = EnvMat(self.rcut, self.rcut_smth)
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
self.nnei = np.sum(self.sel)
self.davg = np.zeros(
[self.ntypes, self.nnei, 1], dtype=PRECISION_DICT[self.precision]
Expand Down Expand Up @@ -305,6 +307,7 @@ def serialize(self) -> dict:
"trainable": self.trainable,
"type_one_side": self.type_one_side,
"exclude_types": self.exclude_types,
"env_protection": self.env_protection,
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function,
# make deterministic
Expand Down Expand Up @@ -333,7 +336,6 @@ def deserialize(cls, data: dict) -> "DescrptSeR":
obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = NetworkCollection.deserialize(embeddings)
obj.env_mat = EnvMat.deserialize(env_mat)
return obj

@classmethod
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
use_aparam_as_mask=use_aparam_as_mask,
spin=spin,
mixed_types=mixed_types,
exclude_types=exclude_types,
)

@classmethod
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
from .make_model import (
make_model,
)
from .spin_model import (
SpinModel,
)

__all__ = [
"DPModel",
"SpinModel",
"make_model",
]
58 changes: 56 additions & 2 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@
from deepmd.dpmodel.model.dp_model import (
DPModel,
)
from deepmd.dpmodel.model.spin_model import (
SpinModel,
)
from deepmd.utils.spin import (
Spin,
)


def get_model(data: dict) -> DPModel:
"""Get a DPModel from a dictionary.
def get_standard_model(data: dict) -> DPModel:
"""Get a standard DPModel from a dictionary.
Parameters
----------
Expand All @@ -30,6 +36,7 @@ def get_model(data: dict) -> DPModel:
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
Expand All @@ -41,3 +48,50 @@ def get_model(data: dict) -> DPModel:
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
)


def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
# include virtual spin and placeholder types
data["type_map"] += [item + "_spin" for item in data["type_map"]]
spin = Spin(
use_spin=data["spin"]["use_spin"],
virtual_scale=data["spin"]["virtual_scale"],
)
pair_exclude_types = spin.get_pair_exclude_types(
exclude_types=data.get("pair_exclude_types", None)
)
data["pair_exclude_types"] = pair_exclude_types
# for descriptor data stat
data["descriptor"]["exclude_types"] = pair_exclude_types
atom_exclude_types = spin.get_atom_exclude_types(
exclude_types=data.get("atom_exclude_types", None)
)
data["atom_exclude_types"] = atom_exclude_types
if "env_protection" not in data["descriptor"]:
data["descriptor"]["env_protection"] = 1e-6
if data["descriptor"]["type"] in ["se_e2_a"]:
# only expand sel for se_e2_a
data["descriptor"]["sel"] += data["descriptor"]["sel"]
backbone_model = get_standard_model(data)
return SpinModel(backbone_model=backbone_model, spin=spin)


def get_model(data: dict):
"""Get a model from a dictionary.
Parameters
----------
data : dict
The data to construct the model.
"""
if "spin" in data:
return get_spin_model(data)
else:
return get_standard_model(data)
Loading

0 comments on commit ba493ab

Please sign in to comment.