diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 51de1893b3..53fdd9b71c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
         exclude: ^source/3rdparty
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.9
+    rev: v0.7.0
     hooks:
       - id: ruff
         args: ["--fix"]
@@ -60,7 +60,7 @@ repos:
       - id: blacken-docs
   # C++
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v19.1.1
+    rev: v19.1.2
     hooks:
       - id: clang-format
         exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$)
@@ -74,7 +74,7 @@ repos:
         exclude: ^(source/3rdparty|\.github/workflows|\.clang-format)
   # Shell
   - repo: https://github.com/scop/pre-commit-shfmt
-    rev: v3.9.0-1
+    rev: v3.10.0-1
    hooks:
       - id: shfmt
   # CMake
diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py
index c29a76b3f1..6307b19f41 100644
--- a/deepmd/dpmodel/atomic_model/base_atomic_model.py
+++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py
@@ -1,13 +1,15 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import copy
+import math
 from typing import (
     Optional,
 )
 
+import array_api_compat
 import numpy as np
 
 from deepmd.dpmodel.common import (
     NativeOP,
+    to_numpy_array,
 )
 from deepmd.dpmodel.output_def import (
     FittingOutputDef,
@@ -172,17 +174,18 @@ def forward_common_atomic(
             ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
 
         """
+        xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
         _, nloc, _ = nlist.shape
         atype = extended_atype[:, :nloc]
         if self.pair_excl is not None:
             pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
             # exclude neighbors in the nlist
-            nlist = np.where(pair_mask == 1, nlist, -1)
+            nlist = xp.where(pair_mask == 1, nlist, -1)
 
         ext_atom_mask = self.make_atom_mask(extended_atype)
         ret_dict = self.forward_atomic(
             extended_coord,
-            np.where(ext_atom_mask, extended_atype, 0),
+            xp.where(ext_atom_mask, extended_atype, 0),
             nlist,
             mapping=mapping,
             fparam=fparam,
@@ -191,13 +194,13 @@
         ret_dict = self.apply_out_stat(ret_dict, atype)
 
         # nf x nloc
-        atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
+        atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
         if self.atom_excl is not None:
             atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
 
         for kk in ret_dict.keys():
             out_shape = ret_dict[kk].shape
-            out_shape2 = np.prod(out_shape[2:])
+            out_shape2 = math.prod(out_shape[2:])
             ret_dict[kk] = (
                 ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
                 * atom_mask[:, :, None]
@@ -232,14 +235,15 @@ def serialize(self) -> dict:
             "rcond": self.rcond,
             "preset_out_bias": self.preset_out_bias,
             "@variables": {
-                "out_bias": self.out_bias,
-                "out_std": self.out_std,
+                "out_bias": to_numpy_array(self.out_bias),
+                "out_std": to_numpy_array(self.out_std),
             },
         }
 
     @classmethod
     def deserialize(cls, data: dict) -> "BaseAtomicModel":
-        data = copy.deepcopy(data)
+        # do not deep copy the Descriptor and Fitting classes
+        data = data.copy()
         variables = data.pop("@variables")
         obj = cls(**data)
         for kk in variables.keys():
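The hunks above replace hard-coded np.* calls with a namespace resolved from
the input arrays, so the same code path serves NumPy, JAX, and any other
array-API-compliant backend. A minimal sketch of the pattern, assuming only
array_api_compat and NumPy are installed (the function name is illustrative,
not part of deepmd):

    import array_api_compat
    import numpy as np

    def mask_virtual_neighbors(nlist, pair_mask):
        # Resolve the array namespace (numpy, jax.numpy, ...) from the
        # inputs instead of importing a fixed backend.
        xp = array_api_compat.array_namespace(nlist, pair_mask)
        return xp.where(pair_mask == 1, nlist, -1)

    nlist = np.array([[0, 1, -1]])
    pair_mask = np.array([[1, 0, 1]])
    print(mask_virtual_neighbors(nlist, pair_mask))  # [[ 0 -1 -1]]

In the same spirit, math.prod replaces np.prod on the shape tuple: shapes
are plain Python ints under every backend, so no array library is needed
for that product.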
class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) - fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) + fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting")) data["descriptor"] = descriptor_obj data["fitting"] = fitting_obj obj = super().deserialize(data) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 8cdb7e1f25..dc90f10da7 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,6 +3,7 @@ Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.atomic_model.base_atomic_model import ( @@ -75,7 +76,8 @@ def __init__( else: self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) self.precision_dict = PRECISION_DICT - self.reverse_precision_dict = RESERVED_PRECISON_DICT + # not supported by flax + # self.reverse_precision_dict = RESERVED_PRECISON_DICT self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION @@ -253,9 +255,7 @@ def input_type_cast( str, ]: """Cast the input data to global float type.""" - input_prec = self.reverse_precision_dict[ - self.precision_dict[coord.dtype.name] - ] + input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]] ### ### type checking would not pass jit, convert to coord prec anyway ### @@ -264,10 +264,7 @@ def input_type_cast( for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if ( - input_prec - == self.reverse_precision_dict[self.global_np_float_precision] - ): + if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]: return coord, box, fparam, aparam, input_prec else: pp = self.global_np_float_precision @@ -286,8 +283,7 @@ def output_type_cast( ) -> dict[str, np.ndarray]: """Convert the model output to the input prec.""" do_cast = ( - input_prec - != self.reverse_precision_dict[self.global_np_float_precision] + input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision] ) pp = self.precision_dict[input_prec] odef = self.model_output_def() @@ -366,6 +362,7 @@ def _format_nlist( nnei: int, extra_nlist_sort: bool = False, ): + xp = array_api_compat.array_namespace(extended_coord, nlist) n_nf, n_nloc, n_nnei = nlist.shape extended_coord = extended_coord.reshape([n_nf, -1, 3]) nall = extended_coord.shape[1] @@ -373,10 +370,10 @@ def _format_nlist( if n_nnei < nnei: # make a copy before revise - ret = np.concatenate( + ret = xp.concat( [ nlist, - -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + -1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), ], axis=-1, ) @@ -385,16 +382,16 @@ def _format_nlist( n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 - ret = np.where(m_real_nei, nlist, 0) + ret = xp.where(m_real_nei, nlist, 0) coord0 = extended_coord[:, :n_nloc, :] index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) - coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = np.where(m_real_nei, rr, 
float("inf")) - rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) - ret = np.take_along_axis(ret, ret_mapping, axis=2) - ret = np.where(rr > rcut, -1, ret) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) + ret = xp.take_along_axis(ret, ret_mapping, axis=2) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 43c275b1be..928c33f3bd 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -23,6 +24,7 @@ def fit_output_to_model_output( the model output. """ + xp = array_api_compat.get_namespace(coord_ext) model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] @@ -31,7 +33,7 @@ def fit_output_to_model_output( if vdef.reducible: kk_redu = get_reduce_name(kk) # cast to energy prec brefore reduction - model_ret[kk_redu] = np.sum( + model_ret[kk_redu] = xp.sum( vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis ) if vdef.r_differentiable: diff --git a/deepmd/jax/atomic_model/__init__.py b/deepmd/jax/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py new file mode 100644 index 0000000000..90920879c2 --- /dev/null +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + + +def base_atomic_model_set_attr(name, value): + if name in {"out_bias", "out_std"}: + value = to_jax_array(value) + elif name == "pair_excl" and value is not None: + value = PairExcludeMask(value.ntypes, value.exclude_types) + elif name == "atom_excl" and value is not None: + value = AtomExcludeMask(value.ntypes, value.exclude_types) + return value diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..077209e29a --- /dev/null +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.common import ( + flax_module, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) + + +@flax_module +class DPAtomicModel(DPAtomicModelDP): + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 6ceb116d85..ed59493268 
diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py
index 6ceb116d85..ed59493268 100644
--- a/deepmd/jax/descriptor/__init__.py
+++ b/deepmd/jax/descriptor/__init__.py
@@ -1 +1,12 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.jax.descriptor.dpa1 import (
+    DescrptDPA1,
+)
+from deepmd.jax.descriptor.se_e2_a import (
+    DescrptSeA,
+)
+
+__all__ = [
+    "DescrptSeA",
+    "DescrptDPA1",
+]
diff --git a/deepmd/jax/descriptor/base_descriptor.py b/deepmd/jax/descriptor/base_descriptor.py
new file mode 100644
index 0000000000..7dec3cd6d4
--- /dev/null
+++ b/deepmd/jax/descriptor/base_descriptor.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.dpmodel.descriptor.make_base_descriptor import (
+    make_base_descriptor,
+)
+from deepmd.jax.env import (
+    jnp,
+)
+
+BaseDescriptor = make_base_descriptor(jnp.ndarray)
diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py
index a9b0404970..0528e4bb93 100644
--- a/deepmd/jax/descriptor/dpa1.py
+++ b/deepmd/jax/descriptor/dpa1.py
@@ -16,6 +16,9 @@
     flax_module,
     to_jax_array,
 )
+from deepmd.jax.descriptor.base_descriptor import (
+    BaseDescriptor,
+)
 from deepmd.jax.utils.exclude_mask import (
     PairExcludeMask,
 )
@@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@BaseDescriptor.register("dpa1")
+@BaseDescriptor.register("se_atten")
 @flax_module
 class DescrptDPA1(DescrptDPA1DP):
     def __setattr__(self, name: str, value: Any) -> None:
diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py
index a60a4e9af1..d1a6e9a8d9 100644
--- a/deepmd/jax/descriptor/se_e2_a.py
+++ b/deepmd/jax/descriptor/se_e2_a.py
@@ -8,6 +8,9 @@
     flax_module,
     to_jax_array,
 )
+from deepmd.jax.descriptor.base_descriptor import (
+    BaseDescriptor,
+)
 from deepmd.jax.utils.exclude_mask import (
     PairExcludeMask,
 )
@@ -16,6 +19,8 @@
 )
 
 
+@BaseDescriptor.register("se_e2_a")
+@BaseDescriptor.register("se_a")
 @flax_module
 class DescrptSeA(DescrptSeADP):
     def __setattr__(self, name: str, value: Any) -> None:
diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py
index 6ceb116d85..e72314dcab 100644
--- a/deepmd/jax/fitting/__init__.py
+++ b/deepmd/jax/fitting/__init__.py
@@ -1 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.jax.fitting.fitting import (
+    DOSFittingNet,
+    EnergyFittingNet,
+)
+
+__all__ = [
+    "EnergyFittingNet",
+    "DOSFittingNet",
+]
diff --git a/deepmd/jax/fitting/base_fitting.py b/deepmd/jax/fitting/base_fitting.py
new file mode 100644
index 0000000000..fd9f3a416d
--- /dev/null
+++ b/deepmd/jax/fitting/base_fitting.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.dpmodel.fitting.make_base_fitting import (
+    make_base_fitting,
+)
+from deepmd.jax.env import (
+    jnp,
+)
+
+BaseFitting = make_base_fitting(jnp.ndarray)
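Each descriptor and fitting class registers itself on the freshly created JAX
base class under the same type strings the DP backend uses, so a config that
says "se_e2_a" resolves to the JAX implementation automatically. A simplified
sketch of the registry mechanics behind register and get_class_by_type (the
real machinery lives in make_base_descriptor / make_base_fitting; this
Registry class is illustrative only):

    class Registry:
        _plugins: dict = {}

        @classmethod
        def register(cls, key: str):
            def decorator(impl):
                cls._plugins[key] = impl  # map type string -> class
                return impl
            return decorator

        @classmethod
        def get_class_by_type(cls, key: str):
            return cls._plugins[key]

    @Registry.register("se_e2_a")
    @Registry.register("se_a")  # alias, same class
    class DescrptSeA:
        pass

    assert Registry.get_class_by_type("se_a") is DescrptSeA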
diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py
index 284213c70a..f979db4d41 100644
--- a/deepmd/jax/fitting/fitting.py
+++ b/deepmd/jax/fitting/fitting.py
@@ -9,6 +9,9 @@
     flax_module,
     to_jax_array,
 )
+from deepmd.jax.fitting.base_fitting import (
+    BaseFitting,
+)
 from deepmd.jax.utils.exclude_mask import (
     AtomExcludeMask,
 )
@@ -33,6 +36,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
     return value
 
 
+@BaseFitting.register("ener")
 @flax_module
 class EnergyFittingNet(EnergyFittingNetDP):
     def __setattr__(self, name: str, value: Any) -> None:
@@ -40,6 +44,7 @@ def __setattr__(self, name: str, value: Any) -> None:
         return super().__setattr__(name, value)
 
 
+@BaseFitting.register("dos")
 @flax_module
 class DOSFittingNet(DOSFittingNetDP):
     def __setattr__(self, name: str, value: Any) -> None:
diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py
new file mode 100644
index 0000000000..05a60c4ffe
--- /dev/null
+++ b/deepmd/jax/model/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from .ener_model import (
+    EnergyModel,
+)
+
+__all__ = ["EnergyModel"]
diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py
new file mode 100644
index 0000000000..fee4855da3
--- /dev/null
+++ b/deepmd/jax/model/base_model.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from deepmd.dpmodel.model.base_model import (
+    make_base_model,
+)
+
+BaseModel = make_base_model()
diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py
new file mode 100644
index 0000000000..79c5a29e88
--- /dev/null
+++ b/deepmd/jax/model/ener_model.py
@@ -0,0 +1,24 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+from deepmd.dpmodel.model import EnergyModel as EnergyModelDP
+from deepmd.jax.atomic_model.dp_atomic_model import (
+    DPAtomicModel,
+)
+from deepmd.jax.common import (
+    flax_module,
+)
+from deepmd.jax.model.base_model import (
+    BaseModel,
+)
+
+
+@BaseModel.register("ener")
+@flax_module
+class EnergyModel(EnergyModelDP):
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "atomic_model":
+            value = DPAtomicModel.deserialize(value.serialize())
+        return super().__setattr__(name, value)
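EnergyModel swaps the atomic model handed over by the DP-backend constructor
for its JAX counterpart via a serialize/deserialize round trip, which works
because both sides share the same dict format. Conceptually (minimal
hypothetical classes, not the deepmd API):

    class DPThing:
        def __init__(self, x):
            self.x = x

        def serialize(self) -> dict:
            return {"x": self.x}

    class JAXThing(DPThing):
        @classmethod
        def deserialize(cls, data: dict) -> "JAXThing":
            return cls(**data)

    # backend swap through the shared plain-dict representation
    jax_obj = JAXThing.deserialize(DPThing(1.0).serialize())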
+ """ + model_type = data.get("type", "standard") + if model_type == "standard": + if "spin" in data: + raise NotImplementedError("Spin model is not implemented yet.") + else: + return get_standard_model(data) + else: + return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 7c8a95c5e7..71f81b0a12 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -283,7 +283,9 @@ def train( # update init_model or init_frz_model config if necessary if (init_model is not None or init_frz_model is not None) and use_pretrain_script: if init_model is not None: - init_state_dict = torch.load(init_model, map_location=DEVICE) + init_state_dict = torch.load( + init_model, map_location=DEVICE, weights_only=True + ) if "model" in init_state_dict: init_state_dict = init_state_dict["model"] config["model"] = init_state_dict["_extra_state"]["model_params"] @@ -380,7 +382,9 @@ def change_bias( output: Optional[str] = None, ): if input_file.endswith(".pt"): - old_state_dict = torch.load(input_file, map_location=env.DEVICE) + old_state_dict = torch.load( + input_file, map_location=env.DEVICE, weights_only=True + ) model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) model_params = model_state_dict["_extra_state"]["model_params"] elif input_file.endswith(".pth"): diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 0a77a38135..acf985974c 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -103,7 +103,9 @@ def __init__( self.output_def = output_def self.model_path = model_file if str(self.model_path).endswith(".pt"): - state_dict = torch.load(model_file, map_location=env.DEVICE) + state_dict = torch.load( + model_file, map_location=env.DEVICE, weights_only=True + ) if "model" in state_dict: state_dict = state_dict["model"] self.input_param = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index dfb7abdb21..b3d120cbc4 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -34,7 +34,7 @@ def __init__( - config: The Dict-like configuration with training options. """ # Model - state_dict = torch.load(model_ckpt, map_location=DEVICE) + state_dict = torch.load(model_ckpt, map_location=DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 10e841682a..0f7c030a84 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -400,7 +400,9 @@ def get_lr(lr_params): optimizer_state_dict = None if resuming: log.info(f"Resuming from {resume_model}.") - state_dict = torch.load(resume_model, map_location=DEVICE) + state_dict = torch.load( + resume_model, map_location=DEVICE, weights_only=True + ) if "model" in state_dict: optimizer_state_dict = ( state_dict["optimizer"] if finetune_model is None else None diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 2dd2230b54..96a420bf6a 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -136,7 +136,7 @@ def get_finetune_rules( Fine-tuning rules in a dict format, with `model_branch`: FinetuneRuleItem pairs. 
""" multi_task = "model_dict" in model_config - state_dict = torch.load(finetune_model, map_location=env.DEVICE) + state_dict = torch.load(finetune_model, map_location=env.DEVICE, weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] last_model_params = state_dict["_extra_state"]["model_params"] diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index a4f81a23a5..1060b40ce1 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -457,7 +457,7 @@ def extend_coord_with_ghosts( xyz = xyz.view(-1, 3) xyz = xyz.to(device=device, non_blocking=True) # ns x 3 - shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))] + shift_idx = xyz[torch.argsort(torch.linalg.norm(xyz, dim=-1))] ns, _ = shift_idx.shape nall = ns * nloc # nf x ns x 3 diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index aab6d100a5..1c6ea096aa 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -33,7 +33,7 @@ def serialize_from_file(model_file: str) -> dict: model = get_model(model_def_script) model.load_state_dict(saved_model.state_dict()) elif model_file.endswith(".pt"): - state_dict = torch.load(model_file, map_location="cpu") + state_dict = torch.load(model_file, map_location="cpu", weights_only=True) if "model" in state_dict: state_dict = state_dict["model"] model_def_script = state_dict["_extra_state"]["model_params"] diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 4079a8d424..3f65375865 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -198,7 +198,7 @@ Enable compilation optimization for the native machine's CPU type. Do not enable **Type**: string -Control high (double) or low (float) precision of training. +Additional CMake arguments. 
diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md
index 4079a8d424..3f65375865 100644
--- a/doc/install/install-from-source.md
+++ b/doc/install/install-from-source.md
@@ -198,7 +198,7 @@ Enable compilation optimization for the native machine's CPU type. Do not enable
 
 **Type**: string
 
-Control high (double) or low (float) precision of training.
+Additional CMake arguments.
 
 :::
 
 :::{envvar} FLAGS
diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py
index 294edec1d6..4112e09cff 100644
--- a/source/tests/consistent/model/common.py
+++ b/source/tests/consistent/model/common.py
@@ -6,8 +6,12 @@
 from deepmd.common import (
     make_default_mesh,
 )
+from deepmd.dpmodel.common import (
+    to_numpy_array,
+)
 
 from ..common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
 )
@@ -20,6 +24,11 @@
         GLOBAL_TF_FLOAT_PRECISION,
         tf,
     )
+if INSTALLED_JAX:
+    from deepmd.jax.common import to_jax_array as numpy_to_jax
+    from deepmd.jax.env import (
+        jnp,
+    )
 
 
 class ModelTest:
@@ -62,3 +71,17 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any:
                 box=numpy_to_torch(box),
             ).items()
         }
+
+    def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any:
+        def assert_jax_array(arr):
+            assert isinstance(arr, jnp.ndarray) or arr is None
+            return arr
+
+        return {
+            kk: to_numpy_array(assert_jax_array(vv))
+            for kk, vv in jax_obj(
+                numpy_to_jax(coords),
+                numpy_to_jax(atype),
+                box=numpy_to_jax(box),
+            ).items()
+        }
diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py
index 692e1287dc..78a2aac703 100644
--- a/source/tests/consistent/model/test_ener.py
+++ b/source/tests/consistent/model/test_ener.py
@@ -13,6 +13,7 @@
 )
 
 from ..common import (
+    INSTALLED_JAX,
     INSTALLED_PT,
     INSTALLED_TF,
     CommonTest,
@@ -36,6 +37,12 @@
     model_args,
 )
 
+if INSTALLED_JAX:
+    from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
+    from deepmd.jax.model.model import get_model as get_model_jax
+else:
+    EnergyModelJAX = None
+
 
 @parameterized(
     (
@@ -84,14 +91,20 @@ def data(self) -> dict:
     tf_class = EnergyModelTF
     dp_class = EnergyModelDP
     pt_class = EnergyModelPT
+    jax_class = EnergyModelJAX
     args = model_args()
 
+    @property
     def skip_tf(self):
         return (
             self.data["pair_exclude_types"] != []
             or self.data["atom_exclude_types"] != []
         )
 
+    @property
+    def skip_jax(self):
+        return not INSTALLED_JAX
+
     def pass_data_to_cls(self, cls, data) -> Any:
         """Pass data to the class."""
         data = data.copy()
@@ -99,6 +112,8 @@ def pass_data_to_cls(self, cls, data) -> Any:
             return get_model_dp(data)
         elif cls is EnergyModelPT:
             return get_model_pt(data)
+        elif cls is EnergyModelJAX:
+            return get_model_jax(data)
         return cls(**data, **self.addtional_data)
 
     def setUp(self):
@@ -168,6 +183,15 @@ def eval_pt(self, pt_obj: Any) -> Any:
             self.box,
         )
 
+    def eval_jax(self, jax_obj: Any) -> Any:
+        return self.eval_jax_model(
+            jax_obj,
+            self.natoms,
+            self.coords,
+            self.atype,
+            self.box,
+        )
+
     def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
         # shape not matched. ravel...
         if backend is self.RefBackend.DP:
@@ -176,4 +200,6 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
             return (ret["energy"].ravel(), ret["atom_energy"].ravel())
         elif backend is self.RefBackend.TF:
             return (ret[0].ravel(), ret[1].ravel())
+        elif backend is self.RefBackend.JAX:
+            return (ret["energy_redu"].ravel(), ret["energy"].ravel())
         raise ValueError(f"Unknown backend: {backend}")
diff --git a/source/tests/pt/model/test_descriptor_dpa1.py b/source/tests/pt/model/test_descriptor_dpa1.py
index 488cc2f7ff..a3d696516a 100644
--- a/source/tests/pt/model/test_descriptor_dpa1.py
+++ b/source/tests/pt/model/test_descriptor_dpa1.py
@@ -245,13 +245,15 @@ def test_descriptor_block(self):
         des = DescrptBlockSeAtten(
             **dparams,
         ).to(env.DEVICE)
-        des.load_state_dict(torch.load(self.file_model_param))
+        des.load_state_dict(torch.load(self.file_model_param, weights_only=True))
         coord = self.coord
         atype = self.atype
         box = self.cell
         # handle type_embedding
         type_embedding = TypeEmbedNet(ntypes, 8, use_tebd_bias=True).to(env.DEVICE)
-        type_embedding.load_state_dict(torch.load(self.file_type_embed))
+        type_embedding.load_state_dict(
+            torch.load(self.file_type_embed, weights_only=True)
+        )
         ## to save model parameters
         # torch.save(des.state_dict(), 'model_weights.pth')
@@ -299,8 +301,8 @@ def test_descriptor(self):
             **dparams,
         ).to(env.DEVICE)
         target_dict = des.state_dict()
-        source_dict = torch.load(self.file_model_param)
-        type_embd_dict = torch.load(self.file_type_embed)
+        source_dict = torch.load(self.file_model_param, weights_only=True)
+        type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
         target_dict = translate_se_atten_and_type_embd_dicts_to_dpa1(
             target_dict,
             source_dict,
diff --git a/source/tests/pt/model/test_descriptor_dpa2.py b/source/tests/pt/model/test_descriptor_dpa2.py
index ac04bfc417..17d609a2f9 100644
--- a/source/tests/pt/model/test_descriptor_dpa2.py
+++ b/source/tests/pt/model/test_descriptor_dpa2.py
@@ -123,10 +123,10 @@ def test_descriptor(self):
             **dparams,
         ).to(env.DEVICE)
         target_dict = des.state_dict()
-        source_dict = torch.load(self.file_model_param)
+        source_dict = torch.load(self.file_model_param, weights_only=True)
         # type_embd of repformer is removed
         source_dict.pop("type_embedding.embedding.embedding_net.layers.0.bias")
-        type_embd_dict = torch.load(self.file_type_embed)
+        type_embd_dict = torch.load(self.file_type_embed, weights_only=True)
         target_dict = translate_type_embd_dicts_to_dpa2(
             target_dict,
             source_dict,
diff --git a/source/tests/pt/model/test_saveload_dpa1.py b/source/tests/pt/model/test_saveload_dpa1.py
index 3da06938b5..5b2b6cd583 100644
--- a/source/tests/pt/model/test_saveload_dpa1.py
+++ b/source/tests/pt/model/test_saveload_dpa1.py
@@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
         optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
         optimizer.zero_grad()
         if read:
-            wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
+            wrapper.load_state_dict(
+                torch.load(model_file, map_location=env.DEVICE, weights_only=True)
+            )
             os.remove(model_file)
         else:
             torch.save(wrapper.state_dict(), model_file)
diff --git a/source/tests/pt/model/test_saveload_se_e2_a.py b/source/tests/pt/model/test_saveload_se_e2_a.py
index 56ea3283d9..d226f628bc 100644
--- a/source/tests/pt/model/test_saveload_se_e2_a.py
+++ b/source/tests/pt/model/test_saveload_se_e2_a.py
@@ -85,7 +85,9 @@ def get_model_result(self, read=False, model_file="tmp_model.pt"):
         optimizer = torch.optim.Adam(wrapper.parameters(), lr=self.start_lr)
         optimizer.zero_grad()
         if read:
-            wrapper.load_state_dict(torch.load(model_file, map_location=env.DEVICE))
+            wrapper.load_state_dict(
+                torch.load(model_file, map_location=env.DEVICE, weights_only=True)
+            )
             os.remove(model_file)
         else:
             torch.save(wrapper.state_dict(), model_file)
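Taken together, the registrations let the new deepmd.jax.model.model.get_model
build a complete JAX energy model from the usual input dict. A hedged usage
sketch (the descriptor and fitting_net keys are abbreviated and illustrative;
a real input needs the full sections, and JAX plus flax must be installed):

    from deepmd.jax.model.model import get_model

    model = get_model(
        {
            "type_map": ["O", "H"],
            "descriptor": {
                "type": "se_e2_a",  # resolves via BaseDescriptor.register
                "rcut": 6.0,
                "rcut_smth": 0.5,
                "sel": [46, 92],
            },
            "fitting_net": {
                "type": "ener",  # also picks the model class ("ener")
            },
        }
    )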