diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 7131f4d534..cfb0936bda 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -34,11 +34,11 @@ class JAXBackend(Backend): features: ClassVar[Backend.Feature] = ( Backend.Feature.IO | Backend.Feature.ENTRY_POINT - # | Backend.Feature.DEEP_EVAL + | Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".jax"] + suffixes: ClassVar[list[str]] = [".hlo", ".jax"] """The suffixes of the backend.""" def is_available(self) -> bool: @@ -71,7 +71,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - raise NotImplementedError + from deepmd.jax.infer.deep_eval import ( + DeepEval, + ) + + return DeepEval @property def neighbor_stat(self) -> type["NeighborStat"]: diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index feebe57af7..6c0efb94d4 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -555,7 +555,7 @@ def call( coord_ext, atype_ext, nlist, self.davg, self.dstd ) nf, nloc, nnei, _ = rr.shape - sec = xp.asarray(self.sel_cumsum) + sec = self.sel_cumsum ng = self.neuron[-1] gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index e36182e712..b6379573e1 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + Callable, Optional, ) @@ -39,6 +40,95 @@ ) +def model_call_from_call_lower( + *, # enforce keyword-only arguments + call_lower: Callable[ + [ + np.ndarray, + np.ndarray, + np.ndarray, + Optional[np.ndarray], + Optional[np.ndarray], + bool, + ], + dict[str, np.ndarray], + ], + rcut: float, + sel: list[int], + mixed_types: bool, + model_output_def: ModelOutputDef, + coord: np.ndarray, + atype: np.ndarray, + box: Optional[np.ndarray] = None, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + do_atomic_virial: bool = False, +): + """Return model prediction from lower interface. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + nframes, nloc = atype.shape[:2] + cc, bb, fp, ap = coord, box, fparam, aparam + del coord, box, fparam, aparam + if bb is not None: + coord_normalized = normalize_coord( + cc.reshape(nframes, nloc, 3), + bb.reshape(nframes, 3, 3), + ) + else: + coord_normalized = cc.copy() + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + ) + model_predict = communicate_extended_output( + model_predict_lower, + model_output_def, + mapping, + do_atomic_virial=do_atomic_virial, + ) + return model_predict + + def make_model(T_AtomicModel: type[BaseAtomicModel]): """Make a model as a derived class of an atomic model. @@ -130,45 +220,23 @@ def call( The keys are defined by the `ModelOutputDef`. """ - nframes, nloc = atype.shape[:2] cc, bb, fp, ap, input_prec = self.input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam - if bb is not None: - coord_normalized = normalize_coord( - cc.reshape(nframes, nloc, 3), - bb.reshape(nframes, 3, 3), - ) - else: - coord_normalized = cc.copy() - extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, bb, self.get_rcut() - ) - nlist = build_neighbor_list( - extended_coord, - extended_atype, - nloc, - self.get_rcut(), - self.get_sel(), - distinguish_types=not self.mixed_types(), - ) - extended_coord = extended_coord.reshape(nframes, -1, 3) - model_predict_lower = self.call_lower( - extended_coord, - extended_atype, - nlist, - mapping, + model_predict = model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=cc, + atype=atype, + box=bb, fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = communicate_extended_output( - model_predict_lower, - self.model_output_def(), - mapping, - do_atomic_virial=do_atomic_virial, - ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/dpmodel/utils/serialization.py b/deepmd/dpmodel/utils/serialization.py index 5e70ec6769..37702cc9f0 100644 --- a/deepmd/dpmodel/utils/serialization.py +++ b/deepmd/dpmodel/utils/serialization.py @@ -90,7 +90,7 @@ def save_dp_model(filename: str, model_dict: dict) -> None: # use UTC+0 time "time": str(datetime.datetime.now(tz=datetime.timezone.utc)), } - if filename_extension == ".dp": + if filename_extension in (".dp", ".hlo"): variable_counter = Counter() with h5py.File(filename, "w") as f: model_dict = traverse_model_dict( @@ -141,7 +141,7 @@ def load_dp_model(filename: str) -> dict: The loaded model dict, including meta information. """ filename_extension = Path(filename).suffix - if filename_extension == ".dp": + if filename_extension in {".dp", ".hlo"}: with h5py.File(filename, "r") as f: model_dict = json.loads(f.attrs["json"]) model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy()) diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index ee11e17125..1b90433b00 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -8,6 +8,7 @@ from flax import ( nnx, ) +from jax import export as jax_export jax.config.update("jax_enable_x64", True) # jax.config.update("jax_debug_nans", True) @@ -16,4 +17,5 @@ "jax", "jnp", "nnx", + "jax_export", ] diff --git a/deepmd/jax/infer/__init__.py b/deepmd/jax/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py new file mode 100644 index 0000000000..76f044a327 --- /dev/null +++ b/deepmd/jax/infer/deep_eval.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Union, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.output_def import ( + ModelOutputDef, + OutputVariableCategory, + OutputVariableDef, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.infer.deep_dipole import ( + DeepDipole, +) +from deepmd.infer.deep_dos import ( + DeepDOS, +) +from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper +from deepmd.infer.deep_eval import ( + DeepEvalBackend, +) +from deepmd.infer.deep_polar import ( + DeepPolar, +) +from deepmd.infer.deep_pot import ( + DeepPot, +) +from deepmd.infer.deep_wfc import ( + DeepWFC, +) +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.model.hlo import ( + HLO, +) +from deepmd.jax.utils.auto_batch_size import ( + AutoBatchSize, +) + +if TYPE_CHECKING: + import ase.neighborlist + + +class DeepEval(DeepEvalBackend): + """NumPy backend implementation of DeepEval. + + Parameters + ---------- + model_file : str + The name of the frozen model file. + output_def : ModelOutputDef + The output definition of the model. + *args : list + Positional arguments. + auto_batch_size : bool or int or AutoBatchSize, default: True + If True, automatic batch size will be used. If int, it will be used + as the initial batch size. + neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional + The ASE neighbor list class to produce the neighbor list. If None, the + neighbor list will be built natively in the model. + **kwargs : dict + Keyword arguments. + """ + + def __init__( + self, + model_file: str, + output_def: ModelOutputDef, + *args: Any, + auto_batch_size: Union[bool, int, AutoBatchSize] = True, + neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, + **kwargs: Any, + ): + self.output_def = output_def + self.model_path = model_file + + model_data = load_dp_model(model_file) + self.dp = HLO( + stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + model_def_script=model_data["model_def_script"], + **model_data["constants"], + ) + self.rcut = self.dp.get_rcut() + self.type_map = self.dp.get_type_map() + if isinstance(auto_batch_size, bool): + if auto_batch_size: + self.auto_batch_size = AutoBatchSize() + else: + self.auto_batch_size = None + elif isinstance(auto_batch_size, int): + self.auto_batch_size = AutoBatchSize(auto_batch_size) + elif isinstance(auto_batch_size, AutoBatchSize): + self.auto_batch_size = auto_batch_size + else: + raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") + + def get_rcut(self) -> float: + """Get the cutoff radius of this model.""" + return self.rcut + + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> list[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return self.dp.get_dim_fparam() + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return self.dp.get_dim_aparam() + + @property + def model_type(self) -> type["DeepEvalWrapper"]: + """The evaluator of the model type.""" + model_output_type = self.dp.model_output_type() + if "energy" in model_output_type: + return DeepPot + elif "dos" in model_output_type: + return DeepDOS + elif "dipole" in model_output_type: + return DeepDipole + elif "polar" in model_output_type: + return DeepPolar + elif "wfc" in model_output_type: + return DeepWFC + else: + raise RuntimeError("Unknown model type") + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.dp.get_sel_type() + + def get_numb_dos(self) -> int: + """Get the number of DOS.""" + return 0 + + def get_has_efield(self): + """Check if the model has efield.""" + return False + + def get_ntypes_spin(self): + """Get the number of spin atom types of this model.""" + return 0 + + def eval( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + atomic: bool = False, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + **kwargs: Any, + ) -> dict[str, np.ndarray]: + """Evaluate the energy, force and virial by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 3 x 3 + atom_types + The atom types + The list should contain natoms ints + atomic + Calculate the atomic energy and virial + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + **kwargs + Other parameters + + Returns + ------- + output_dict : dict + The output of the evaluation. The keys are the names of the output + variables, and the values are the corresponding output arrays. + """ + if fparam is not None or aparam is not None: + raise NotImplementedError + # convert all of the input to numpy array + atom_types = np.array(atom_types, dtype=np.int32) + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, len(atom_types.shape) > 1 + ) + request_defs = self._get_request_defs(atomic) + out = self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, request_defs + ) + return dict( + zip( + [x.name for x in request_defs], + out, + ) + ) + + def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: + """Get the requested output definitions. + + When atomic is True, all output_def are requested. + When atomic is False, only energy (tensor), force, and virial + are requested. + + Parameters + ---------- + atomic : bool + Whether to request the atomic output. + + Returns + ------- + list[OutputVariableDef] + The requested output definitions. + """ + if atomic: + return list(self.output_def.var_defs.values()) + else: + return [ + x + for x in self.output_def.var_defs.values() + if x.category + in ( + OutputVariableCategory.REDU, + OutputVariableCategory.DERV_R, + OutputVariableCategory.DERV_C_REDU, + ) + ] + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size. + + Parameters + ---------- + inner_func : Callable + the method to be wrapped + numb_test : int + number of tests + natoms : int + number of atoms + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) + + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: np.ndarray, + mixed_type: bool = False, + ) -> tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + if natoms == 0: + assert coords.size == 0 + else: + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _eval_model( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + request_defs: list[OutputVariableDef], + ): + model = self.dp + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = coords.reshape([-1, natoms, 3]) + type_input = atom_types + if cells is not None: + box_input = cells.reshape([-1, 3, 3]) + else: + box_input = None + + do_atomic_virial = any( + x.category == OutputVariableCategory.DERV_C_REDU for x in request_defs + ) + batch_output = model( + to_jax_array(coord_input), + to_jax_array(type_input), + box=to_jax_array(box_input), + do_atomic_virial=do_atomic_virial, + ) + if isinstance(batch_output, tuple): + batch_output = batch_output[0] + for kk, vv in batch_output.items(): + batch_output[kk] = to_numpy_array(vv) + + results = [] + for odef in request_defs: + # it seems not doing conversion + # dp_name = self._OUTDEF_DP2BACKEND[odef.name] + dp_name = odef.name + if dp_name in batch_output: + shape = self._get_output_shape(odef, nframes, natoms) + if batch_output[dp_name] is not None: + out = batch_output[dp_name].reshape(shape) + else: + out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append( + np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + ) # this is kinda hacky + return tuple(results) + + def _get_output_shape(self, odef, nframes, natoms): + if odef.category == OutputVariableCategory.DERV_C_REDU: + # virial + return [nframes, *odef.shape[:-1], 9] + elif odef.category == OutputVariableCategory.REDU: + # energy + return [nframes, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_C: + # atom_virial + return [nframes, *odef.shape[:-1], natoms, 9] + elif odef.category == OutputVariableCategory.DERV_R: + # force + return [nframes, *odef.shape[:-1], natoms, 3] + elif odef.category == OutputVariableCategory.OUT: + # atom_energy, atom_tensor + return [nframes, natoms, *odef.shape, 1] + else: + raise RuntimeError("unknown category") + + def get_model_def_script(self) -> dict: + """Get model definition script.""" + return json.loads(self.dp.get_model_def_script()) diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py new file mode 100644 index 0000000000..010e3d7a5e --- /dev/null +++ b/deepmd/jax/model/hlo.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +from deepmd.dpmodel.model.make_model import ( + model_call_from_call_lower, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.jax.env import ( + jax_export, + jnp, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +OUTPUT_DEFS = { + "energy": OutputVariableDef( + "energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + "mask": OutputVariableDef( + "mask", + shape=[1], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), +} + + +class HLO(BaseModel): + def __init__( + self, + stablehlo, + model_def_script, + type_map, + rcut, + dim_fparam, + dim_aparam, + sel_type, + is_aparam_nall, + model_output_type, + mixed_types, + min_nbor_dist, + sel, + ) -> None: + self._call_lower = jax_export.deserialize(stablehlo).call + self.stablehlo = stablehlo + self.type_map = type_map + self.rcut = rcut + self.dim_fparam = dim_fparam + self.dim_aparam = dim_aparam + self.sel_type = sel_type + self._is_aparam_nall = is_aparam_nall + self._model_output_type = model_output_type + self._mixed_types = mixed_types + self.min_nbor_dist = min_nbor_dist + self.sel = sel + self.model_def_script = model_def_script + + def __call__( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Any: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return self.call(coord, atype, box, fparam, aparam, do_atomic_virial) + + def call( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,np.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def model_output_def(self): + return ModelOutputDef( + FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) + ) + + def call_lower( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + return self._call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial, + ) + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.type_map + + def get_rcut(self): + """Get the cut-off radius.""" + return self.rcut + + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.dim_fparam + + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.dim_aparam + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.sel_type + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self._is_aparam_nall + + def model_output_type(self) -> list[str]: + """Get the output type for the model.""" + return self._model_output_type + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented") + + @classmethod + def deserialize(cls, data: dict) -> "BaseModel": + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized model + """ + raise NotImplementedError("Not implemented") + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.model_def_script + + def get_min_nbor_dist(self) -> Optional[float]: + """Get the minimum distance between two atoms.""" + return self.min_nbor_dist + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.nsel + + def get_sel(self) -> list[int]: + return self.sel + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.sel) + + def mixed_types(self) -> bool: + return self._mixed_types + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + raise NotImplementedError("Not implemented") + + @classmethod + def get_model(cls, model_params: dict) -> "BaseModel": + """Get the model by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method. + + Parameters + ---------- + model_params : dict + The model parameters + + Returns + ------- + BaseBaseModel + The model + """ + raise NotImplementedError("Not implemented") diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 43070f8a07..fcfcc8a610 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -3,10 +3,17 @@ Path, ) +import numpy as np import orbax.checkpoint as ocp +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, +) from deepmd.jax.env import ( jax, + jax_export, + jnp, nnx, ) from deepmd.jax.model.model import ( @@ -39,6 +46,44 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_def_script=ocp.args.JsonSave(model_def_script), ), ) + elif model_file.endswith(".hlo"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + nf, nloc, nghost, nfp, nap = jax_export.symbolic_shape( + "nf, nloc, nghost, nfp, nap" + ) + exported = jax_export.export(jax.jit(call_lower))( + jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, nfp), jnp.float64) + if model.get_dim_fparam() + else None, # fparam + jax.ShapeDtypeStruct((nf, nap), jnp.float64) + if model.get_dim_aparam() + else None, # aparam + False, # do_atomic_virial + ) + serialized: bytearray = exported.serialize() + data = data.copy() + data.setdefault("@variables", {}) + data["@variables"]["stablehlo"] = np.void(serialized) + data["constants"] = { + "type_map": model.get_type_map(), + "rcut": model.get_rcut(), + "dim_fparam": model.get_dim_fparam(), + "dim_aparam": model.get_dim_aparam(), + "sel_type": model.get_sel_type(), + "is_aparam_nall": model.is_aparam_nall(), + "model_output_type": model.model_output_type(), + "mixed_types": model.mixed_types(), + "min_nbor_dist": model.get_min_nbor_dist(), + "sel": model.get_sel(), + } + save_dp_model(filename=model_file, model_dict=data) else: raise ValueError("JAX backend only supports converting .jax directory") @@ -93,5 +138,10 @@ def convert_str_to_int_key(item: dict): "@variables": {}, } return data + elif model_file.endswith(".hlo"): + data = load_dp_model(model_file) + data.pop("constants") + data["@variables"].pop("stablehlo") + return data else: raise ValueError("JAX backend only supports converting .jax directory") diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index df81c24ff5..dc0f280d56 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -74,14 +74,21 @@ def tearDown(self): @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() - for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): + for backend_name, suffix_idx in ( + ("tensorflow", 0), + ("pytorch", 0), + ("dpmodel", 0), + ("jax", 0), + ): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) - self.save_data_to_model(prefix + backend.suffixes[0], reference_data) - data = self.get_data_from_model(prefix + backend.suffixes[0]) + self.save_data_to_model( + prefix + backend.suffixes[suffix_idx], reference_data + ) + data = self.get_data_from_model(prefix + backend.suffixes[suffix_idx]) data = copy.deepcopy(data) reference_data = copy.deepcopy(self.data) # some keys are not expected to be not the same @@ -131,7 +138,7 @@ def test_deep_eval(self): ).reshape(1, 9) prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] - for backend_name in ("tensorflow", "pytorch", "dpmodel"): + for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue