Skip to content

Commit

Permalink
fix: expose more dpmodel interface
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Nov 18, 2024
1 parent 3355f50 commit 5a11ff8
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 29 deletions.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from deepmd.dpmodel.atomic_model import (
DPDipoleAtomicModel,
)
from deepmd.dpmodel.model.model import (
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from deepmd.dpmodel.atomic_model import (
DPDOSAtomicModel,
)
from deepmd.dpmodel.model.model import (
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

Expand Down
79 changes: 53 additions & 26 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting.dos_fitting import (
DOSFittingNet,
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
Expand All @@ -29,8 +29,35 @@
from deepmd.utils.spin import (
Spin,
)
import copy
from deepmd.dpmodel.model.dos_model import DOSModel
from deepmd.dpmodel.model.property_model import PropertyModel
from deepmd.dpmodel.model.dipole_model import DipoleModel
from deepmd.dpmodel.model.polar_model import PolarModel


def _get_standard_model_components(data, ntypes):
# descriptor
data["descriptor"]["ntypes"] = ntypes
data["descriptor"]["type_map"] = copy.deepcopy(data["type_map"])
descriptor = BaseDescriptor(**data["descriptor"])
# fitting
fitting_net = data.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(data["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
return descriptor, fitting, fitting_net["type"]

def get_standard_model(data: dict) -> EnergyModel:
"""Get a EnergyModel from a dictionary.
Expand All @@ -43,36 +70,36 @@ def get_standard_model(data: dict) -> EnergyModel:
raise ValueError(
"In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details."
)
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
descriptor = BaseDescriptor(
**data["descriptor"],
data = copy.deepcopy(data)
ntypes = len(data["type_map"])
descriptor, fitting, fitting_net_type = _get_standard_model_components(
data, ntypes
)
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
elif fitting_type == "dos":
fitting = DOSFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
atom_exclude_types = data.get("atom_exclude_types", [])
pair_exclude_types = data.get("pair_exclude_types", [])


if fitting_net_type == "dipole":
modelcls = DipoleModel
elif fitting_net_type == "polar":
modelcls = PolarModel
elif fitting_net_type == "dos":
modelcls = DOSModel
elif fitting_net_type in ["ener", "direct_force_ener"]:
modelcls = EnergyModel
elif fitting_net_type == "property":
modelcls = PropertyModel
else:
raise ValueError(f"Unknown fitting type {fitting_type}") # fix
return EnergyModel(
raise RuntimeError(f"Unknown fitting type: {fitting_net_type}")

model = modelcls(
descriptor=descriptor,
fitting=fitting,
type_map=data["type_map"],
atom_exclude_types=data.get("atom_exclude_types", []),
pair_exclude_types=data.get("pair_exclude_types", []),
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
return model


def get_zbl_model(data: dict) -> DPZBLModel:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/polar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from deepmd.dpmodel.atomic_model import (
DPPolarAtomicModel,
)
from deepmd.dpmodel.model.model import (
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

Expand Down
10 changes: 10 additions & 0 deletions source/tests/consistent/model/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,16 @@ def build_tf_model(
ret["dos"],
ret["atom_dos"],
]
elif ret_key == "dipole":
ret_list = [
ret["global_dipole"],
ret["dipole"],
]
elif ret_key == "polar":
ret_list = [
ret["polar"],
ret["global_polar"],
]
else:
raise NotImplementedError
return ret_list, {
Expand Down

0 comments on commit 5a11ff8

Please sign in to comment.