diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6a1d303f64..30efa6b062 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,7 @@ repos: exclude: ^source/3rdparty - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.9 + rev: v0.7.0 hooks: - id: ruff args: ["--fix"] @@ -52,7 +52,7 @@ repos: - id: blacken-docs # C++ - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v19.1.1 + rev: v19.1.2 hooks: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$) @@ -66,7 +66,7 @@ repos: exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt - rev: v3.9.0-1 + rev: v3.10.0-1 hooks: - id: shfmt # CMake diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index c29a76b3f1..6307b19f41 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy +import math from typing import ( Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( NativeOP, + to_numpy_array, ) from deepmd.dpmodel.output_def import ( FittingOutputDef, @@ -172,17 +174,18 @@ def forward_common_atomic( ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual. """ + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) # exclude neighbors in the nlist - nlist = np.where(pair_mask == 1, nlist, -1) + nlist = xp.where(pair_mask == 1, nlist, -1) ext_atom_mask = self.make_atom_mask(extended_atype) ret_dict = self.forward_atomic( extended_coord, - np.where(ext_atom_mask, extended_atype, 0), + xp.where(ext_atom_mask, extended_atype, 0), nlist, mapping=mapping, fparam=fparam, @@ -191,13 +194,13 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) + atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32) if self.atom_excl is not None: atom_mask *= self.atom_excl.build_type_exclude_mask(atype) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape - out_shape2 = np.prod(out_shape[2:]) + out_shape2 = math.prod(out_shape[2:]) ret_dict[kk] = ( ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] @@ -232,14 +235,15 @@ def serialize(self) -> dict: "rcond": self.rcond, "preset_out_bias": self.preset_out_bias, "@variables": { - "out_bias": self.out_bias, - "out_std": self.out_std, + "out_bias": to_numpy_array(self.out_bias), + "out_std": to_numpy_array(self.out_std), }, } @classmethod def deserialize(cls, data: dict) -> "BaseAtomicModel": - data = copy.deepcopy(data) + # do not deep copy Descriptor and Fitting class + data = data.copy() variables = data.pop("@variables") obj = cls(**data) for kk in variables.keys(): diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 7e576eb484..fe049021fe 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -169,14 +169,20 @@ def serialize(self) -> dict: ) return dd + # for subclass overriden + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) - fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) + fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting")) data["descriptor"] = descriptor_obj data["fitting"] = fitting_obj obj = super().deserialize(data) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 8cdb7e1f25..dc90f10da7 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,6 +3,7 @@ Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.atomic_model.base_atomic_model import ( @@ -75,7 +76,8 @@ def __init__( else: self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) self.precision_dict = PRECISION_DICT - self.reverse_precision_dict = RESERVED_PRECISON_DICT + # not supported by flax + # self.reverse_precision_dict = RESERVED_PRECISON_DICT self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION @@ -253,9 +255,7 @@ def input_type_cast( str, ]: """Cast the input data to global float type.""" - input_prec = self.reverse_precision_dict[ - self.precision_dict[coord.dtype.name] - ] + input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]] ### ### type checking would not pass jit, convert to coord prec anyway ### @@ -264,10 +264,7 @@ def input_type_cast( for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if ( - input_prec - == self.reverse_precision_dict[self.global_np_float_precision] - ): + if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]: return coord, box, fparam, aparam, input_prec else: pp = self.global_np_float_precision @@ -286,8 +283,7 @@ def output_type_cast( ) -> dict[str, np.ndarray]: """Convert the model output to the input prec.""" do_cast = ( - input_prec - != self.reverse_precision_dict[self.global_np_float_precision] + input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision] ) pp = self.precision_dict[input_prec] odef = self.model_output_def() @@ -366,6 +362,7 @@ def _format_nlist( nnei: int, extra_nlist_sort: bool = False, ): + xp = array_api_compat.array_namespace(extended_coord, nlist) n_nf, n_nloc, n_nnei = nlist.shape extended_coord = extended_coord.reshape([n_nf, -1, 3]) nall = extended_coord.shape[1] @@ -373,10 +370,10 @@ def _format_nlist( if n_nnei < nnei: # make a copy before revise - ret = np.concatenate( + ret = xp.concat( [ nlist, - -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + -1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), ], axis=-1, ) @@ -385,16 +382,16 @@ def _format_nlist( n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 - ret = np.where(m_real_nei, nlist, 0) + ret = xp.where(m_real_nei, nlist, 0) coord0 = extended_coord[:, :n_nloc, :] index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) - coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = np.where(m_real_nei, rr, float("inf")) - rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) - ret = np.take_along_axis(ret, ret_mapping, axis=2) - ret = np.where(rr > rcut, -1, ret) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) + ret = xp.take_along_axis(ret, ret_mapping, axis=2) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 43c275b1be..928c33f3bd 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -23,6 +24,7 @@ def fit_output_to_model_output( the model output. """ + xp = array_api_compat.get_namespace(coord_ext) model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] @@ -31,7 +33,7 @@ def fit_output_to_model_output( if vdef.reducible: kk_redu = get_reduce_name(kk) # cast to energy prec brefore reduction - model_ret[kk_redu] = np.sum( + model_ret[kk_redu] = xp.sum( vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis ) if vdef.r_differentiable: diff --git a/deepmd/jax/atomic_model/__init__.py b/deepmd/jax/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py new file mode 100644 index 0000000000..90920879c2 --- /dev/null +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + + +def base_atomic_model_set_attr(name, value): + if name in {"out_bias", "out_std"}: + value = to_jax_array(value) + elif name == "pair_excl" and value is not None: + value = PairExcludeMask(value.ntypes, value.exclude_types) + elif name == "atom_excl" and value is not None: + value = AtomExcludeMask(value.ntypes, value.exclude_types) + return value diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..077209e29a --- /dev/null +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) + + +@flax_module +class DPAtomicModel(DPAtomicModelDP): + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 6ceb116d85..ed59493268 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -1 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", + "DescrptDPA1", +] diff --git a/deepmd/jax/descriptor/base_descriptor.py b/deepmd/jax/descriptor/base_descriptor.py new file mode 100644 index 0000000000..7dec3cd6d4 --- /dev/null +++ b/deepmd/jax/descriptor/base_descriptor.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) +from deepmd.jax.env import ( + jnp, +) + +BaseDescriptor = make_base_descriptor(jnp.ndarray) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index a9b0404970..0528e4bb93 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -16,6 +16,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") @flax_module class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index a60a4e9af1..d1a6e9a8d9 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -8,6 +8,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -16,6 +19,8 @@ ) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") @flax_module class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index 6ceb116d85..e72314dcab 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.fitting.fitting import ( + DOSFittingNet, + EnergyFittingNet, +) + +__all__ = [ + "EnergyFittingNet", + "DOSFittingNet", +] diff --git a/deepmd/jax/fitting/base_fitting.py b/deepmd/jax/fitting/base_fitting.py new file mode 100644 index 0000000000..fd9f3a416d --- /dev/null +++ b/deepmd/jax/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.fitting.make_base_fitting import ( + make_base_fitting, +) +from deepmd.jax.env import ( + jnp, +) + +BaseFitting = make_base_fitting(jnp.ndarray) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 284213c70a..f979db4d41 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) from deepmd.jax.utils.exclude_mask import ( AtomExcludeMask, ) @@ -33,6 +36,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any: return value +@BaseFitting.register("ener") @flax_module class EnergyFittingNet(EnergyFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: @@ -40,6 +44,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py new file mode 100644 index 0000000000..05a60c4ffe --- /dev/null +++ b/deepmd/jax/model/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .ener_model import ( + EnergyModel, +) + +__all__ = ["EnergyModel"] diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py new file mode 100644 index 0000000000..fee4855da3 --- /dev/null +++ b/deepmd/jax/model/base_model.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.model.base_model import ( + make_base_model, +) + +BaseModel = make_base_model() diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py new file mode 100644 index 0000000000..79c5a29e88 --- /dev/null +++ b/deepmd/jax/model/ener_model.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.model import EnergyModel as EnergyModelDP +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +@BaseModel.register("ener") +@flax_module +class EnergyModel(EnergyModelDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = DPAtomicModel.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py new file mode 100644 index 0000000000..7fa3efda6e --- /dev/null +++ b/deepmd/jax/model/model.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) + +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def get_standard_model(data: dict): + """Get a Model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + data = deepcopy(data) + descriptor_type = data["descriptor"].pop("type") + data["descriptor"]["type_map"] = data["type_map"] + fitting_type = data["fitting_net"].pop("type") + data["fitting_net"]["type_map"] = data["type_map"] + descriptor = BaseDescriptor.get_class_by_type(descriptor_type)( + **data["descriptor"], + ) + fitting = BaseFitting.get_class_by_type(fitting_type)( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + return BaseModel.get_class_by_type(fitting_type)( + descriptor=descriptor, + fitting=fitting, + type_map=data["type_map"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=data.get("pair_exclude_types", []), + ) + + +def get_model(data: dict): + """Get a model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + model_type = data.get("type", "standard") + if model_type == "standard": + if "spin" in data: + raise NotImplementedError("Spin model is not implemented yet.") + else: + return get_standard_model(data) + else: + return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 7c8a95c5e7..71f81b0a12 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -283,7 +283,9 @@ def train( # update init_model or init_frz_model config if necessary if (init_model is not None or init_frz_model is not None) and use_pretrain_script: if init_model is not None: - init_state_dict = torch.load(init_model, map_location=DEVICE) + init_state_dict = torch.load( + init_model, map_location=DEVICE, weights_only=True + ) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] @@ -380,7 +382,9 @@ def change_bias( output: Optional[str] = None, ): if input_file.endswith(".pt"): - old_state_dict = torch.load(input_file, map_location=env.DEVICE) + old_state_dict = torch.load( + input_file, map_location=env.DEVICE, weights_only=True + ) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] elif input_file.endswith(".pth"): diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 084c9b282d..1d5f086ff7 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -103,7 +103,9 @@ def __init__( self.output_def = output_def self.model_path = model_file if str(self.model_path).endswith(".pt"): - state_dict = torch.load(model_file, map_location=env.DEVICE) + state_dict = torch.load( + model_file, map_location=env.DEVICE, weights_only=True + ) if "model" in state_dict: state_dict = state_dict["model"] self.input_param = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index 4204020a0d..1e6b138764 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -34,7 +34,7 @@ def __init__( - config: The Dict-like configuration with training options. """ # Model - state_dict = torch.load(model_ckpt, map_location=DEVICE) + state_dict = torch.load(model_ckpt, map_location=DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 481b612557..2ebebd8400 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -411,7 +411,9 @@ def get_lr(lr_params): optimizer_state_dict = None if resuming: log.info(f"Resuming from {resume_model}.") - state_dict = torch.load(resume_model, map_location=DEVICE) + state_dict = torch.load( + resume_model, map_location=DEVICE, weights_only=True + ) if "model" in state_dict: optimizer_state_dict = ( state_dict["optimizer"] if finetune_model is None else None diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 2dd2230b54..96a420bf6a 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -136,7 +136,7 @@ def get_finetune_rules( Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. """ multi_task = "model_dict" in model_config - state_dict = torch.load(finetune_model, map_location=env.DEVICE) + state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] last_model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index aab6d100a5..1c6ea096aa 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -33,7 +33,7 @@ def serialize_from_file(model_file: str) -> dict: model = get_model(model_def_script) model.load_state_dict(saved_model.state_dict()) elif model_file.endswith(".pt"): - state_dict = torch.load(model_file, map_location="cpu") + state_dict = torch.load(model_file, map_location="cpu", weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_def_script = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index dd69dc40bc..e7d68e85f3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1387,24 +1387,27 @@ def descrpt_se_a_mask_args(): def descrpt_variant_type_args(exclude_hybrid: bool = False) -> Variant: - link_lf = make_link("loc_frame", "model/descriptor[loc_frame]") - link_se_e2_a = make_link("se_e2_a", "model/descriptor[se_e2_a]") - link_se_e2_r = make_link("se_e2_r", "model/descriptor[se_e2_r]") - link_se_e3 = make_link("se_e3", "model/descriptor[se_e3]") - link_se_a_tpe = make_link("se_a_tpe", "model/descriptor[se_a_tpe]") - link_hybrid = make_link("hybrid", "model/descriptor[hybrid]") - link_se_atten = make_link("se_atten", "model/descriptor[se_atten]") - link_se_atten_v2 = make_link("se_atten_v2", "model/descriptor[se_atten_v2]") - doc_descrpt_type = "The type of the descritpor. See explanation below. \n\n\ -- `loc_frame`: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\ -- `se_e2_a`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\ -- `se_e2_r`: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\ -- `se_e3`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n\ -- `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\ -- `se_atten`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism will be used by this descriptor.\n\n\ -- `se_atten_v2`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism with new modifications will be used by this descriptor.\n\n\ -- `se_a_mask`: Used by the smooth edition of Deep Potential. It can accept a variable number of atoms in a frame (Non-PBC system). *aparam* are required as an indicator matrix for the real/virtual sign of input atoms. \n\n\ -- `hybrid`: Concatenate of a list of descriptors as a new descriptor." + link_lf = make_link("loc_frame", "model[standard]/descriptor[loc_frame]") + link_se_e2_a = make_link("se_e2_a", "model[standard]/descriptor[se_e2_a]") + link_se_e2_r = make_link("se_e2_r", "model[standard]/descriptor[se_e2_r]") + link_se_e3 = make_link("se_e3", "model[standard]/descriptor[se_e3]") + link_se_a_tpe = make_link("se_a_tpe", "model[standard]/descriptor[se_a_tpe]") + link_hybrid = make_link("hybrid", "model[standard]/descriptor[hybrid]") + link_se_atten = make_link("se_atten", "model[standard]/descriptor[se_atten]") + link_se_atten_v2 = make_link( + "se_atten_v2", "model[standard]/descriptor[se_atten_v2]" + ) + link_se_a_mask = make_link("se_a_mask", "model[standard]/descriptor[se_a_mask]") + doc_descrpt_type = f"The type of the descritpor. See explanation below. \n\n\ +- {link_lf}: Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame.\n\n\ +- {link_se_e2_a}: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor.\n\n\ +- {link_se_e2_r}: Used by the smooth edition of Deep Potential. Only the distance between atoms is used to construct the descriptor.\n\n\ +- {link_se_e3}: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Three-body embedding will be used by this descriptor.\n\n\ +- {link_se_a_tpe}: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\ +- {link_se_atten}: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism will be used by this descriptor.\n\n\ +- {link_se_atten_v2}: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism with new modifications will be used by this descriptor.\n\n\ +- {link_se_a_mask}: Used by the smooth edition of Deep Potential. It can accept a variable number of atoms in a frame (Non-PBC system). *aparam* are required as an indicator matrix for the real/virtual sign of input atoms. \n\n\ +- {link_hybrid}: Concatenate of a list of descriptors as a new descriptor." return Variant( "type", @@ -1692,7 +1695,7 @@ def fitting_variant_type_args(): # --- Modifier configurations: --- # def modifier_dipole_charge(): doc_model_name = "The name of the frozen dipole model file." - doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model/fitting_net[dipole]/sel_type')}. " + doc_model_charge_map = f"The charge of the WFCC. The list length should be the same as the {make_link('sel_type', 'model[standard]/fitting_net[dipole]/sel_type')}. " doc_sys_charge_map = f"The charge of real atoms. The list length should be the same as the {make_link('type_map', 'model/type_map')}" doc_ewald_h = "The grid spacing of the FFT grid. Unit is A" doc_ewald_beta = f"The splitting parameter of Ewald sum. Unit is A^{-1}" diff --git a/doc/freeze/compress.md b/doc/freeze/compress.md index 3cce96c993..e26c85e45a 100644 --- a/doc/freeze/compress.md +++ b/doc/freeze/compress.md @@ -99,7 +99,7 @@ The model compression interface requires the version of DeePMD-kit used in the o Descriptors with `se_e2_a`, `se_e3`, `se_e2_r` and `se_atten_v2` types are supported by the model compression feature. `Hybrid` mixed with the above descriptors is also supported. -Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer ` set to 0. +Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer ` set to 0. **Available activation functions for descriptor:** diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 4079a8d424..3f65375865 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -198,7 +198,7 @@ Enable compilation optimization for the native machine's CPU type. Do not enable **Type**: string -Control high (double) or low (float) precision of training. +Additional CMake arguments. ::: :::{envvar} FLAGS diff --git a/doc/model/dplr.md b/doc/model/dplr.md index ec95f9f424..91c2251346 100644 --- a/doc/model/dplr.md +++ b/doc/model/dplr.md @@ -58,7 +58,7 @@ Two settings make the training input script different from an energy training in }, ``` -The type of fitting is set to {ref}`dipole `. The dipole is associated with type 0 atoms (oxygens), by the setting `"dipole_type": [0]`. What we trained is the displacement of the WC from the corresponding oxygen atom. It shares the same training input as the atomic dipole because both are 3-dimensional vectors defined on atoms. +The type of fitting is set to {ref}`dipole `. The dipole is associated with type 0 atoms (oxygens), by the setting `"dipole_type": [0]`. What we trained is the displacement of the WC from the corresponding oxygen atom. It shares the same training input as the atomic dipole because both are 3-dimensional vectors defined on atoms. The loss section is provided as follows ```json diff --git a/doc/model/dprc.md b/doc/model/dprc.md index 33dde237d7..d9ce24b600 100644 --- a/doc/model/dprc.md +++ b/doc/model/dprc.md @@ -140,7 +140,7 @@ As described in the paper, the DPRc model only corrects $E_\text{QM}$ and $E_\te :::: -{ref}`exclude_types ` can be generated by the following Python script: +{ref}`exclude_types ` can be generated by the following Python script: ```py from itertools import combinations_with_replacement, product @@ -163,7 +163,7 @@ print( ) ``` -Also, DPRc assumes MM atom energies ({ref}`atom_ener `) are zero: +Also, DPRc assumes MM atom energies ({ref}`atom_ener `) are zero: ```json "fitting_net": { @@ -173,7 +173,7 @@ Also, DPRc assumes MM atom energies ({ref}`atom_ener ` only works when {ref}`descriptor/set_davg_zero ` of the QM/MM part is `true`. +Note that {ref}`atom_ener ` only works when {ref}`descriptor/set_davg_zero ` of the QM/MM part is `true`. ## Run MD simulations diff --git a/doc/model/overall.md b/doc/model/overall.md index b0ea19ac5d..8499c64b79 100644 --- a/doc/model/overall.md +++ b/doc/model/overall.md @@ -42,7 +42,7 @@ A model has two parts, a descriptor that maps atomic configuration to a set of s } ``` -The two subsections, {ref}`descriptor ` and {ref}`fitting_net `, define the descriptor and the fitting net, respectively. +The two subsections, {ref}`descriptor ` and {ref}`fitting_net `, define the descriptor and the fitting net, respectively. The {ref}`type_map ` is optional, which provides the element names (but not necessarily same as the actual name of the element) of the corresponding atom types. A water model, as in this example, has two kinds of atoms. The atom types are internally recorded as integers, e.g., `0` for oxygen and `1` for hydrogen here. A mapping from the atom type to their names is provided by {ref}`type_map `. diff --git a/doc/model/train-energy-spin.md b/doc/model/train-energy-spin.md index 9f4e3cf04b..ec169892f2 100644 --- a/doc/model/train-energy-spin.md +++ b/doc/model/train-energy-spin.md @@ -11,9 +11,9 @@ keeping other sections the same as the normal energy model's input script. Note that when adding spin into the model, there will be some implicit modifications automatically done by the program: - In the TensorFlow backend, the `se_e2_a` descriptor will treat those atom types with spin as new (virtual) types, - and duplicate their corresponding selected numbers of neighbors ({ref}`sel `) from their real atom types. + and duplicate their corresponding selected numbers of neighbors ({ref}`sel `) from their real atom types. - In the PyTorch backend, if spin settings are added, all the types (with or without spin) will have their virtual types. - The `se_e2_a` descriptor will thus double the {ref}`sel ` list, + The `se_e2_a` descriptor will thus double the {ref}`sel ` list, while in other descriptors with mixed types (such as `dpa1` or `dpa2`), the sel number will not be changed for clarity. If you are using descriptors with mixed types, to achieve better performance, you should manually extend your sel number (maybe double) depending on the balance between performance and efficiency. diff --git a/doc/model/train-energy.md b/doc/model/train-energy.md index c1da1f4c1f..75d31d4670 100644 --- a/doc/model/train-energy.md +++ b/doc/model/train-energy.md @@ -79,7 +79,7 @@ Benefiting from the relative force loss, small forces can be fitted more accurat ## The fitting network -The construction of the fitting net is given by section {ref}`fitting_net ` +The construction of the fitting net is given by section {ref}`fitting_net ` ```json "fitting_net" : { @@ -89,9 +89,9 @@ The construction of the fitting net is given by section {ref}`fitting_net ` specifies the size of the fitting net. If two neighboring layers are of the same size, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. -- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. -- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. +- {ref}`neuron ` specifies the size of the fitting net. If two neighboring layers are of the same size, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. +- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. +- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. ## Loss diff --git a/doc/model/train-fitting-dos.md b/doc/model/train-fitting-dos.md index 4c4366a1e1..d04dbc669c 100644 --- a/doc/model/train-fitting-dos.md +++ b/doc/model/train-fitting-dos.md @@ -16,11 +16,11 @@ $deepmd_source_dir/examples/dos/input.json The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.** -Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md). To fit the `dos`, one needs to modify {ref}`model/fitting_net ` and {ref}`loss `. +Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md). To fit the `dos`, one needs to modify {ref}`model[standard]/fitting_net ` and {ref}`loss `. ## The fitting Network -The {ref}`fitting_net ` section tells DP which fitting net to use. +The {ref}`fitting_net ` section tells DP which fitting net to use. The JSON of `dos` type should be provided like diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index 4d5cb22707..c6b54c69ef 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -30,7 +30,7 @@ $deepmd_source_dir/examples/water_tensor/polar/polar_input_torch.json The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.** -Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md). To fit a tensor, one needs to modify {ref}`model/fitting_net ` and {ref}`loss `. +Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md). To fit a tensor, one needs to modify {ref}`model[standard]/fitting_net ` and {ref}`loss `. ## Theory @@ -72,7 +72,7 @@ The tensorial models can be used to calculate IR spectrum and Raman spectrum.[^1 ## The fitting Network -The {ref}`fitting_net ` section tells DP which fitting net to use. +The {ref}`fitting_net ` section tells DP which fitting net to use. ::::{tab-set} diff --git a/doc/model/train-hybrid.md b/doc/model/train-hybrid.md index c0a55d9eb5..1219d208a7 100644 --- a/doc/model/train-hybrid.md +++ b/doc/model/train-hybrid.md @@ -25,7 +25,7 @@ This way, one can set the different cutoff radii for different descriptors.[^1] ## Instructions -To use the descriptor in DeePMD-kit, one firstly set the {ref}`type ` to {ref}`hybrid `, then provide the definitions of the descriptors by the items in the `list`, +To use the descriptor in DeePMD-kit, one firstly set the {ref}`type ` to {ref}`hybrid `, then provide the definitions of the descriptors by the items in the `list`, ```json "descriptor" :{ diff --git a/doc/model/train-se-a-mask.md b/doc/model/train-se-a-mask.md index 6757fbefbd..69f344b138 100644 --- a/doc/model/train-se-a-mask.md +++ b/doc/model/train-se-a-mask.md @@ -29,7 +29,7 @@ A complete training input script of this example can be found in the directory. $deepmd_source_dir/examples/zinc_protein/zinc_se_a_mask.json ``` -The construction of the descriptor is given by section {ref}`descriptor `. An example of the descriptor is provided as follows +The construction of the descriptor is given by section {ref}`descriptor `. An example of the descriptor is provided as follows ```json "descriptor" :{ @@ -43,13 +43,13 @@ The construction of the descriptor is given by section {ref}`descriptor ` of the descriptor is set to `"se_a_mask"`. -- {ref}`sel ` gives the maximum number of atoms in input coordinates. It is a list, the length of which is the same as the number of atom types in the system, and `sel[i]` denotes the maximum number of atoms with type `i`. -- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. -- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) -- If the option {ref}`type_one_side ` is set to `true`, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters. -- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. -- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. +- The {ref}`type ` of the descriptor is set to `"se_a_mask"`. +- {ref}`sel ` gives the maximum number of atoms in input coordinates. It is a list, the length of which is the same as the number of atom types in the system, and `sel[i]` denotes the maximum number of atoms with type `i`. +- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. +- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) +- If the option {ref}`type_one_side ` is set to `true`, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters. +- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. +- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. To make the `aparam.npy` used for descriptor `se_a_mask`, two variables in `fitting_net` section are needed. @@ -63,9 +63,9 @@ To make the `aparam.npy` used for descriptor `se_a_mask`, two variables in `fitt } ``` -- `neuron`, `resnet_dt` and `seed` are the same as the {ref}`fitting_net ` section for fitting energy. -- {ref}`numb_aparam ` gives the dimesion of the `aparam.npy` file. In this example, it is set to 1 and stores the real/virtual sign of the atoms. For real/virtual atoms, the corresponding sign in `aparam.npy` is set to 1/0. -- {ref}`use_aparam_as_mask ` is set to `true` to use the `aparam.npy` as the mask of the atoms in the descriptor `se_a_mask`. +- `neuron`, `resnet_dt` and `seed` are the same as the {ref}`fitting_net ` section for fitting energy. +- {ref}`numb_aparam ` gives the dimesion of the `aparam.npy` file. In this example, it is set to 1 and stores the real/virtual sign of the atoms. For real/virtual atoms, the corresponding sign in `aparam.npy` is set to 1/0. +- {ref}`use_aparam_as_mask ` is set to `true` to use the `aparam.npy` as the mask of the atoms in the descriptor `se_a_mask`. Finally, to make a reasonable fitting task with `se_a_mask` descriptor for DP/MM simulations, the loss function with `se_a_mask` is designed to include the atomic forces difference in specific atoms of the input particles only. More details about the selection of the specific atoms can be found in paper [DP/MM](left to be filled). diff --git a/doc/model/train-se-atten.md b/doc/model/train-se-atten.md index 24950d9595..bebce78365 100644 --- a/doc/model/train-se-atten.md +++ b/doc/model/train-se-atten.md @@ -104,17 +104,17 @@ An example of the DPA-1 descriptor is provided as follows } ``` -- The {ref}`type ` of the descriptor is set to `"se_atten"`, which will use DPA-1 structures. -- {ref}`rcut ` is the cut-off radius for neighbor searching, and the {ref}`rcut_smth ` gives where the smoothing starts. -- **{ref}`sel `** gives the maximum possible number of neighbors in the cut-off radius. It is an int. Note that this number highly affects the efficiency of training, which we usually use less than 200. (We use 120 for training 56 elements in [OC2M dataset](https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md)) -- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. -- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) -- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. -- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. -- {ref}`attn ` sets the length of a hidden vector during scale-dot attention computation. -- {ref}`attn_layer ` sets the number of layers in attention mechanism. -- {ref}`attn_mask ` determines whether to mask the diagonal in the attention weights and False is recommended. -- {ref}`attn_dotr ` determines whether to dot the relative coordinates on the attention weights as a gated scheme, True is recommended. +- The {ref}`type ` of the descriptor is set to `"se_atten"`, which will use DPA-1 structures. +- {ref}`rcut ` is the cut-off radius for neighbor searching, and the {ref}`rcut_smth ` gives where the smoothing starts. +- **{ref}`sel `** gives the maximum possible number of neighbors in the cut-off radius. It is an int. Note that this number highly affects the efficiency of training, which we usually use less than 200. (We use 120 for training 56 elements in [OC2M dataset](https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md)) +- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. +- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) +- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. +- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. +- {ref}`attn ` sets the length of a hidden vector during scale-dot attention computation. +- {ref}`attn_layer ` sets the number of layers in attention mechanism. +- {ref}`attn_mask ` determines whether to mask the diagonal in the attention weights and False is recommended. +- {ref}`attn_dotr ` determines whether to dot the relative coordinates on the attention weights as a gated scheme, True is recommended. ### Descriptor `"se_atten_v2"` @@ -138,7 +138,7 @@ You can use descriptor `"se_atten_v2"` and do not need to set `tebd_input_mode` Practical evidence demonstrates that `"se_atten_v2"` offers better and more stable performance compared to `"se_atten"`. -Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer ` set to 0. +Notice: Model compression for the `se_atten_v2` descriptor is exclusively designed for models with the training parameter {ref}`attn_layer ` set to 0. ### Fitting `"ener"` diff --git a/doc/model/train-se-e2-a.md b/doc/model/train-se-e2-a.md index 2412bbc64e..81b95399e0 100644 --- a/doc/model/train-se-e2-a.md +++ b/doc/model/train-se-e2-a.md @@ -70,7 +70,7 @@ $deepmd_source_dir/examples/water/se_e2_a/input.json With the training input script, data are also provided in the example directory. One may train the model with the DeePMD-kit from the directory. -The construction of the descriptor is given by section {ref}`descriptor `. An example of the descriptor is provided as follows +The construction of the descriptor is given by section {ref}`descriptor `. An example of the descriptor is provided as follows ```json "descriptor" :{ @@ -86,11 +86,11 @@ The construction of the descriptor is given by section {ref}`descriptor ` of the descriptor is set to `"se_e2_a"`. -- {ref}`rcut ` is the cut-off radius for neighbor searching, and the {ref}`rcut_smth ` gives where the smoothing starts. -- {ref}`sel ` gives the maximum possible number of neighbors in the cut-off radius. It is a list, the length of which is the same as the number of atom types in the system, and `sel[i]` denotes the maximum possible number of neighbors with type `i`. -- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. -- If the option {ref}`type_one_side ` is set to `true`, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters. -- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) -- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. -- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. +- The {ref}`type ` of the descriptor is set to `"se_e2_a"`. +- {ref}`rcut ` is the cut-off radius for neighbor searching, and the {ref}`rcut_smth ` gives where the smoothing starts. +- {ref}`sel ` gives the maximum possible number of neighbors in the cut-off radius. It is a list, the length of which is the same as the number of atom types in the system, and `sel[i]` denotes the maximum possible number of neighbors with type `i`. +- The {ref}`neuron ` specifies the size of the embedding net. From left to right the members denote the sizes of each hidden layer from the input end to the output end, respectively. If the outer layer is twice the size of the inner layer, then the inner layer is copied and concatenated, then a [ResNet architecture](https://arxiv.org/abs/1512.03385) is built between them. +- If the option {ref}`type_one_side ` is set to `true`, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters. +- The {ref}`axis_neuron ` specifies the size of the submatrix of the embedding matrix, the axis matrix as explained in the [DeepPot-SE paper](https://arxiv.org/abs/1805.09003) +- If the option {ref}`resnet_dt ` is set to `true`, then a timestep is used in the ResNet. +- {ref}`seed ` gives the random seed that is used to generate random numbers when initializing the model parameters. diff --git a/doc/model/train-se-e2-r.md b/doc/model/train-se-e2-r.md index f427310196..316bde43b4 100644 --- a/doc/model/train-se-e2-r.md +++ b/doc/model/train-se-e2-r.md @@ -52,7 +52,7 @@ A complete training input script of this example can be found in the directory $deepmd_source_dir/examples/water/se_e2_r/input.json ``` -The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the {ref}`descriptor ` section +The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the {ref}`descriptor ` section ```json "descriptor": { @@ -68,4 +68,4 @@ The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.m }, ``` -The type of the descriptor is set by the key {ref}`type `. +The type of the descriptor is set by the key {ref}`type `. diff --git a/doc/model/train-se-e3-tebd.md b/doc/model/train-se-e3-tebd.md index 8b49b0c220..5935a8920a 100644 --- a/doc/model/train-se-e3-tebd.md +++ b/doc/model/train-se-e3-tebd.md @@ -56,7 +56,7 @@ A complete training input script of this example can be found in the directory $deepmd_source_dir/examples/water/se_e3_tebd/input.json ``` -The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the {ref}`descriptor ` section +The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the {ref}`descriptor ` section ```json "descriptor": { @@ -75,4 +75,4 @@ The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.m }, ``` -The type of the descriptor is set by the key {ref}`type `. +The type of the descriptor is set by the key {ref}`type `. diff --git a/doc/model/train-se-e3.md b/doc/model/train-se-e3.md index d650d72493..3d82c42c9e 100644 --- a/doc/model/train-se-e3.md +++ b/doc/model/train-se-e3.md @@ -48,7 +48,7 @@ A complete training input script of this example can be found in the directory $deepmd_source_dir/examples/water/se_e3/input.json ``` -The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the `descriptor ` section +The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.md). The only difference lies in the `descriptor ` section ```json "descriptor": { @@ -63,4 +63,4 @@ The training input script is very similar to that of [`se_e2_a`](train-se-e2-a.m }, ``` -The type of the descriptor is set by the key {ref}`type `. +The type of the descriptor is set by the key {ref}`type `. diff --git a/doc/train/finetuning.md b/doc/train/finetuning.md index 4fbe95b2fd..669d1319bd 100644 --- a/doc/train/finetuning.md +++ b/doc/train/finetuning.md @@ -36,7 +36,7 @@ The elements in the training dataset must be contained in the pre-trained datase The finetune procedure will inherit the model structures in `pretrained.pb`, and thus it will ignore the model parameters in `input.json`, -such as {ref}`descriptor `, {ref}`fitting_net `, +such as {ref}`descriptor `, {ref}`fitting_net `, {ref}`type_embedding ` and {ref}`type_map `. However, you can still set the `trainable` parameters in each part of `input.json` to control the training procedure. diff --git a/doc/train/gpu-limitations.md b/doc/train/gpu-limitations.md index 92577fd65c..44c9697dd4 100644 --- a/doc/train/gpu-limitations.md +++ b/doc/train/gpu-limitations.md @@ -5,5 +5,5 @@ If you use DeePMD-kit in a GPU environment, the acceptable value range of some v 1. The number of atom types of a given system must be less than 128. 2. The maximum distance between an atom and its neighbors must be less than 128. It can be controlled by setting the rcut value of training parameters. 3. Theoretically, the maximum number of atoms that a single GPU can accept is about 10,000,000. However, this value is limited by the GPU memory size currently, usually within 1000,000 atoms even in the model compression mode. -4. The total sel value of training parameters(in `model/descriptor` section) must be less than 4096. +4. The total sel value of training parameters(in `model[standard]/descriptor` section) must be less than 4096. 5. The size of the last layer of the embedding net must be less than 1024 during the model compression process. diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index 9be12e9fb8..d21feb2126 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -114,7 +114,7 @@ The section {ref}`mixed_precision ` specifies the mixe - {ref}`output_prec ` precision used in the output tensors, only `float32` is supported currently. - {ref}`compute_prec ` precision used in the computing tensors, only `float16` is supported currently. Note there are several limitations about mixed precision training: -- Only {ref}`se_e2_a ` type descriptor is supported by the mixed precision training workflow. +- Only {ref}`se_e2_a ` type descriptor is supported by the mixed precision training workflow. - The precision of the embedding net and the fitting net are forced to be set to `float32`. Other keys in the {ref}`training ` section are explained below: diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 294edec1d6..4112e09cff 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -6,8 +6,12 @@ from deepmd.common import ( make_default_mesh, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, ) @@ -20,6 +24,11 @@ GLOBAL_TF_FLOAT_PRECISION, tf, ) +if INSTALLED_JAX: + from deepmd.jax.common import to_jax_array as numpy_to_jax + from deepmd.jax.env import ( + jnp, + ) class ModelTest: @@ -62,3 +71,17 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: box=numpy_to_torch(box), ).items() } + + def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any: + def assert_jax_array(arr): + assert isinstance(arr, jnp.ndarray) or arr is None + return arr + + return { + kk: to_numpy_array(assert_jax_array(vv)) + for kk, vv in jax_obj( + numpy_to_jax(coords), + numpy_to_jax(atype), + box=numpy_to_jax(box), + ).items() + } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 692e1287dc..78a2aac703 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -36,6 +37,12 @@ model_args, ) +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None + @parameterized( ( @@ -84,14 +91,20 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + jax_class = EnergyModelJAX args = model_args() + @property def skip_tf(self): return ( self.data["pair_exclude_types"] != [] or self.data["atom_exclude_types"] != [] ) + @property + def skip_jax(self): + return not INSTALLED_JAX + def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() @@ -99,6 +112,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) return cls(**data, **self.addtional_data) def setUp(self): @@ -168,6 +183,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend is self.RefBackend.DP: @@ -176,4 +200,6 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret["energy"].ravel(), ret["atom_energy"].ravel()) elif backend is self.RefBackend.TF: return (ret[0].ravel(), ret[1].ravel()) + elif backend is self.RefBackend.JAX: + return (ret["energy_redu"].ravel(), ret["energy"].ravel()) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py index 488cc2f7ff..a3d696516a 100644 --- a/source/tests/pt/model/test_descriptor_dpa1.py +++ b/source/tests/pt/model/test_descriptor_dpa1.py @@ -245,13 +245,15 @@ def test_descriptor_block(self): des = DescrptBlockSeAtten( **dparams, ).to(env.DEVICE) - des.load_state_dict(torch.load(self.file_model_param)) + des.load_state_dict(torch.load(self.file_model_param, weights_only=True)) coord = self.coord atype = self.atype box = self.cell # handel type_embedding type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE) - type_embedding.load_state_dict(torch.load(self.file_type_embed)) + type_embedding.load_state_dict( + torch.load(self.file_type_embed, weights_only=True) + ) ## to save model parameters # torch.save(des.state_dict(), 'model_weights.pth') @@ -299,8 +301,8 @@ def test_descriptor(self): **dparams, ).to(env.DEVICE) target_dict = des.state_dict() - source_dict = torch.load(self.file_model_param) - type_embd_dict = torch.load(self.file_type_embed) + source_dict = torch.load(self.file_model_param, weights_only=True) + type_embd_dict = torch.load(self.file_type_embed, weights_only=True) target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1( target_dict, source_dict, diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py index ac04bfc417..17d609a2f9 100644 --- a/source/tests/pt/model/test_descriptor_dpa2.py +++ b/source/tests/pt/model/test_descriptor_dpa2.py @@ -123,10 +123,10 @@ def test_descriptor(self): **dparams, ).to(env.DEVICE) target_dict = des.state_dict() - source_dict = torch.load(self.file_model_param) + source_dict = torch.load(self.file_model_param, weights_only=True) # type_embd of repformer is removed source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias") - type_embd_dict = torch.load(self.file_type_embed) + type_embd_dict = torch.load(self.file_type_embed, weights_only=True) target_dict = translate_type_embd_dicts_to_dpa2( target_dict, source_dict, diff --git a/source/tests/pt/model/test_saveload_dpa1.py b/source/tests/pt/model/test_saveload_dpa1.py index 3da06938b5..5b2b6cd583 100644 --- a/source/tests/pt/model/test_saveload_dpa1.py +++ b/source/tests/pt/model/test_saveload_dpa1.py @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) optimizer.zero_grad() if read: - wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + wrapper.load_state_dict( + torch.load(model_file, map_location=env.DEVICE, weights_only=True) + ) os.remove(model_file) else: torch.save(wrapper.state_dict(), model_file) diff --git a/source/tests/pt/model/test_saveload_se_e2_a.py b/source/tests/pt/model/test_saveload_se_e2_a.py index 56ea3283d9..d226f628bc 100644 --- a/source/tests/pt/model/test_saveload_se_e2_a.py +++ b/source/tests/pt/model/test_saveload_se_e2_a.py @@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"): optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr) optimizer.zero_grad() if read: - wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE)) + wrapper.load_state_dict( + torch.load(model_file, map_location=env.DEVICE, weights_only=True) + ) os.remove(model_file) else: torch.save(wrapper.state_dict(), model_file)