diff --git a/deepmd/dpmodel/model/dipole_model.py b/deepmd/dpmodel/model/dipole_model.py index 211debc2e8..4ca523f79b 100644 --- a/deepmd/dpmodel/model/dipole_model.py +++ b/deepmd/dpmodel/model/dipole_model.py @@ -4,7 +4,7 @@ from deepmd.dpmodel.atomic_model import ( DPDipoleAtomicModel, ) -from deepmd.dpmodel.model.model import ( +from deepmd.dpmodel.model.base_model import ( BaseModel, ) diff --git a/deepmd/dpmodel/model/dos_model.py b/deepmd/dpmodel/model/dos_model.py index 638bdb462d..3df887b460 100644 --- a/deepmd/dpmodel/model/dos_model.py +++ b/deepmd/dpmodel/model/dos_model.py @@ -3,7 +3,7 @@ from deepmd.dpmodel.atomic_model import ( DPDOSAtomicModel, ) -from deepmd.dpmodel.model.model import ( +from deepmd.dpmodel.model.base_model import ( BaseModel, ) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 28e29cdcb7..43eeb13898 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -8,8 +8,8 @@ from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.dpmodel.fitting.dos_fitting import ( - DOSFittingNet, +from deepmd.dpmodel.fitting.base_fitting import ( + BaseFitting, ) from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, @@ -29,8 +29,35 @@ from deepmd.utils.spin import ( Spin, ) +import copy +from deepmd.dpmodel.model.dos_model import DOSModel +from deepmd.dpmodel.model.property_model import PropertyModel +from deepmd.dpmodel.model.dipole_model import DipoleModel +from deepmd.dpmodel.model.polar_model import PolarModel +def _get_standard_model_components(data, ntypes): + # descriptor + data["descriptor"]["ntypes"] = ntypes + data["descriptor"]["type_map"] = copy.deepcopy(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + # fitting + fitting_net = data.get("fitting_net", {}) + fitting_net["type"] = fitting_net.get("type", "ener") + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(data["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + if fitting_net["type"] in ["dipole", "polar"]: + fitting_net["embedding_width"] = descriptor.get_dim_emb() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + grad_force = "direct" not in fitting_net["type"] + if not grad_force: + fitting_net["out_dim"] = descriptor.get_dim_emb() + if "ener" in fitting_net["type"]: + fitting_net["return_energy"] = True + fitting = BaseFitting(**fitting_net) + return descriptor, fitting, fitting_net["type"] + def get_standard_model(data: dict) -> EnergyModel: """Get a EnergyModel from a dictionary. @@ -43,36 +70,36 @@ def get_standard_model(data: dict) -> EnergyModel: raise ValueError( "In the DP backend, type_embedding is not at the model level, but within the descriptor. See type embedding documentation for details." ) - data["descriptor"]["type_map"] = data["type_map"] - data["descriptor"]["ntypes"] = len(data["type_map"]) - fitting_type = data["fitting_net"].pop("type") - data["fitting_net"]["type_map"] = data["type_map"] - descriptor = BaseDescriptor( - **data["descriptor"], + data = copy.deepcopy(data) + ntypes = len(data["type_map"]) + descriptor, fitting, fitting_net_type = _get_standard_model_components( + data, ntypes ) - if fitting_type == "ener": - fitting = EnergyFittingNet( - ntypes=descriptor.get_ntypes(), - dim_descrpt=descriptor.get_dim_out(), - mixed_types=descriptor.mixed_types(), - **data["fitting_net"], - ) - elif fitting_type == "dos": - fitting = DOSFittingNet( - ntypes=descriptor.get_ntypes(), - dim_descrpt=descriptor.get_dim_out(), - mixed_types=descriptor.mixed_types(), - **data["fitting_net"], - ) + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + + + if fitting_net_type == "dipole": + modelcls = DipoleModel + elif fitting_net_type == "polar": + modelcls = PolarModel + elif fitting_net_type == "dos": + modelcls = DOSModel + elif fitting_net_type in ["ener", "direct_force_ener"]: + modelcls = EnergyModel + elif fitting_net_type == "property": + modelcls = PropertyModel else: - raise ValueError(f"Unknown fitting type {fitting_type}") # fix - return EnergyModel( + raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") + + model = modelcls( descriptor=descriptor, fitting=fitting, type_map=data["type_map"], - atom_exclude_types=data.get("atom_exclude_types", []), - pair_exclude_types=data.get("pair_exclude_types", []), + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, ) + return model def get_zbl_model(data: dict) -> DPZBLModel: diff --git a/deepmd/dpmodel/model/polar_model.py b/deepmd/dpmodel/model/polar_model.py index c7cce1c1fa..994b3556c2 100644 --- a/deepmd/dpmodel/model/polar_model.py +++ b/deepmd/dpmodel/model/polar_model.py @@ -3,7 +3,7 @@ from deepmd.dpmodel.atomic_model import ( DPPolarAtomicModel, ) -from deepmd.dpmodel.model.model import ( +from deepmd.dpmodel.model.base_model import ( BaseModel, ) diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index ef1c7cf911..bb38abc5b6 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -66,6 +66,16 @@ def build_tf_model( ret["dos"], ret["atom_dos"], ] + elif ret_key == "dipole": + ret_list = [ + ret["global_dipole"], + ret["dipole"], + ] + elif ret_key == "polar": + ret_list = [ + ret["polar"], + ret["global_polar"], + ] else: raise NotImplementedError return ret_list, {