diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 5d86472674..224fdd145c 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -34,6 +34,7 @@ ) +@BaseAtomicModel.register("linear") class LinearEnergyAtomicModel(BaseAtomicModel): """Linear model make linear combinations of several existing models. @@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool: return False +@BaseAtomicModel.register("zbl") class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel): """Model linearly combine a list of AtomicModels. diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index b9c1e93387..9f2891d8c0 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -147,6 +147,31 @@ def compute_input_stats( """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("This descriptor doesn't support compression!") + @abstractmethod def fwd( self, diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 4dc4c965fb..5bc5970a87 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -6,6 +6,7 @@ Union, ) +import array_api_compat import numpy as np from deepmd.dpmodel import ( @@ -13,6 +14,10 @@ PRECISION_DICT, NativeOP, ) +from deepmd.dpmodel.common import ( + get_xp_precision, + to_numpy_array, +) from deepmd.dpmodel.utils import ( EmbeddingNet, EnvMat, @@ -25,9 +30,6 @@ from deepmd.dpmodel.utils.update_sel import ( UpdateSel, ) -from deepmd.env import ( - GLOBAL_NP_FLOAT_PRECISION, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -122,17 +124,18 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.trainable = trainable + self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()] in_dim = 1 # not considiering type embedding - self.embeddings = NetworkCollection( + embeddings = NetworkCollection( ntypes=self.ntypes, ndim=2, network_type="embedding_network", ) for ii, embedding_idx in enumerate( - itertools.product(range(self.ntypes), repeat=self.embeddings.ndim) + itertools.product(range(self.ntypes), repeat=embeddings.ndim) ): - self.embeddings[embedding_idx] = EmbeddingNet( + embeddings[embedding_idx] = EmbeddingNet( in_dim, self.neuron, self.activation_function, @@ -140,8 +143,9 @@ def __init__( self.precision, seed=child_seed(self.seed, ii), ) + self.embeddings = embeddings self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection) - self.nnei = np.sum(self.sel) + self.nnei = sum(self.sel) self.davg = np.zeros( [self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision] ) @@ -299,20 +303,22 @@ def call( The smooth switch function.
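The np.einsum replacements in the DescrptSeT.call() hunk that follows use the array-API pattern this PR applies throughout deepmd.dpmodel: pick the namespace of the inputs and fall back to broadcasting, since the array API standard provides no einsum. A minimal standalone sketch of that pattern (function name and shapes are illustrative, not from the PR):

import array_api_compat
import numpy as np

def pairwise_dot(rr_i, rr_j):
    # dispatch on the inputs' namespace (numpy, jax.numpy, ...), as the PR does
    # with xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
    xp = array_api_compat.array_namespace(rr_i, rr_j)
    # broadcasted equivalent of np.einsum("ijm,ikm->ijk", rr_i, rr_j):
    # (i, j, 1, m) * (i, 1, k, m) summed over m gives (i, j, k)
    return xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)

out = pairwise_dot(np.ones((2, 3, 4)), np.ones((2, 5, 4)))  # shape (2, 3, 5)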
""" del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) # nf x nloc x nnei x 4 rr, diff, ww = self.env_mat.call( coord_ext, atype_ext, nlist, self.davg, self.dstd ) nf, nloc, nnei, _ = rr.shape - sec = np.append([0], np.cumsum(self.sel)) + sec = self.sel_cumsum ng = self.neuron[-1] - result = np.zeros([nf * nloc, ng], dtype=PRECISION_DICT[self.precision]) + result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision)) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) - rr = rr.reshape(nf * nloc, nnei, 4) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + rr = xp.reshape(rr, (nf * nloc, nnei, 4)) + rr = xp.astype(rr, get_xp_precision(xp, self.precision)) for embedding_idx in itertools.product( range(self.ntypes), repeat=self.embeddings.ndim @@ -325,23 +331,26 @@ def call( # nfnl x nt_i x 3 rr_i = rr[:, sec[ti] : sec[ti + 1], 1:] mm_i = exclude_mask[:, sec[ti] : sec[ti + 1]] - rr_i = rr_i * mm_i[:, :, None] + rr_i = rr_i * xp.astype(mm_i[:, :, None], rr_i.dtype) # nfnl x nt_j x 3 rr_j = rr[:, sec[tj] : sec[tj + 1], 1:] mm_j = exclude_mask[:, sec[tj] : sec[tj + 1]] - rr_j = rr_j * mm_j[:, :, None] + rr_j = rr_j * xp.astype(mm_j[:, :, None], rr_j.dtype) # nfnl x nt_i x nt_j - env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + # env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j) + env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1) # nfnl x nt_i x nt_j x 1 env_ij_reshape = env_ij[:, :, :, None] # nfnl x nt_i x nt_j x ng gg = self.embeddings[embedding_idx].call(env_ij_reshape) # nfnl x nt_i x nt_j x ng - res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + # res_ij = np.einsum("ijk,ijkm->im", env_ij, gg) + res_ij = xp.sum(env_ij[:, :, :, None] * gg, axis=(1, 2)) res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j)) result += res_ij # nf x nloc x ng - result = result.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION) + result = xp.reshape(result, (nf, nloc, ng)) + result = xp.astype(result, get_xp_precision(xp, "global")) return result, None, None, None, ww def serialize(self) -> dict: @@ -369,8 +378,8 @@ def serialize(self) -> dict: "exclude_types": self.exclude_types, "env_protection": self.env_protection, "@variables": { - "davg": self.davg, - "dstd": self.dstd, + "davg": to_numpy_array(self.davg), + "dstd": to_numpy_array(self.dstd), }, "type_map": self.type_map, "trainable": self.trainable, diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py new file mode 100644 index 0000000000..ba19785235 --- /dev/null +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + +from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.model.dp_model import ( + DPModelCommon, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +from .make_model import ( + make_model, +) + +DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel) + + +@BaseModel.register("zbl") +class DPZBLModel(DPZBLModel_): + model_type = "zbl" + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: 
Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statistics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel( + train_data, type_map, local_jdata["dpmodel"] + ) + return local_jdata_cpy, min_nbor_dist diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index cccd0732cd..c29240214c 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.dpmodel.descriptor.se_e2_a import ( DescrptSeA, ) @@ -8,6 +17,9 @@ from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dp_zbl_model import ( + DPZBLModel, +) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) @@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel: ) +def get_zbl_model(data: dict) -> DPZBLModel: + data["descriptor"]["ntypes"] = len(data["type_map"]) + descriptor = BaseDescriptor(**data["descriptor"]) + fitting_type = data["fitting_net"].pop("type") + if fitting_type == "ener": + fitting = EnergyFittingNet( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + else: + raise ValueError(f"Unknown fitting type {fitting_type}") + + dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"]) + # pairtab + filepath = data["use_srtab"] + pt_model = PairTabAtomicModel( + filepath, + data["descriptor"]["rcut"], + data["descriptor"]["sel"], + type_map=data["type_map"], + ) + + rmin = data["sw_rmin"] + rmax = data["sw_rmax"] + atom_exclude_types = data.get("atom_exclude_types", []) + pair_exclude_types = data.get("pair_exclude_types", []) + return DPZBLModel( + dp_model, + pt_model, + rmin, + rmax, + type_map=data["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + ) + + def get_spin_model(data: dict) -> SpinModel: """Get a spin model from a dictionary. 
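get_zbl_model() above reads a handful of top-level keys alongside the usual descriptor and fitting sections. A minimal sketch of such an input dict (the file name and numeric values are hypothetical placeholders); get_model() routes it here whenever "use_srtab" is present, as the next hunk shows:

data = {
    "type_map": ["O", "H"],
    "use_srtab": "tab_potential.txt",  # hypothetical pair-table file; its presence selects the ZBL path
    "sw_rmin": 0.8,  # below sw_rmin only the tabulated potential acts;
    "sw_rmax": 1.0,  # beyond sw_rmax only the DP model acts; smooth switch in between
    "descriptor": {"type": "se_e2_a", "rcut": 6.0, "rcut_smth": 0.5, "sel": [46, 92]},
    "fitting_net": {"type": "ener", "neuron": [240, 240, 240]},
}
# model = get_model(data)  # returns a DPZBLModel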
@@ -100,6 +151,8 @@ def get_model(data: dict): if model_type == "standard": if "spin" in data: return get_spin_model(data) + elif "use_srtab" in data: + return get_zbl_model(data) else: return get_standard_model(data) else: diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 5140a88c97..a81ddb69a6 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -572,11 +572,12 @@ def call(self, x): def clear(self): """Clear the network parameters to zero.""" for layer in self.layers: - layer.w.fill(0.0) + xp = array_api_compat.array_namespace(layer.w) + layer.w = xp.zeros_like(layer.w) if layer.b is not None: - layer.b.fill(0.0) + layer.b = xp.zeros_like(layer.b) if layer.idt is not None: - layer.idt.fill(0.0) + layer.idt = xp.zeros_like(layer.idt) return NN diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index cabee5a189..4e55bc7659 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -11,10 +11,14 @@ from deepmd.jax.descriptor.se_e2_r import ( DescrptSeR, ) +from deepmd.jax.descriptor.se_t import ( + DescrptSeT, +) __all__ = [ "DescrptSeA", "DescrptSeR", + "DescrptSeT", "DescrptDPA1", "DescrptHybrid", ] diff --git a/deepmd/jax/descriptor/se_t.py b/deepmd/jax/descriptor/se_t.py new file mode 100644 index 0000000000..029f4231fe --- /dev/null +++ b/deepmd/jax/descriptor/se_t.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.jax.common import ( + ArrayAPIVariable, + flax_module, + to_jax_array, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.jax.utils.network import ( + NetworkCollection, +) + + +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") +@flax_module +class DescrptSeT(DescrptSeTDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 2a6186ac46..d62681490c 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from deepmd.jax.common import ( ArrayAPIVariable, flax_module, @@ -51,6 +54,14 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("property") +@flax_module +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + @BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index d3156f7c84..76115b2810 
100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -24,9 +24,15 @@ from deepmd.pt.utils.env import ( RESERVED_PRECISON_DICT, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -261,6 +267,8 @@ def __init__( if ln_eps is None: ln_eps = 1e-5 + self.tebd_input_mode = tebd_input_mode + del type, spin, attn_mask self.se_atten = DescrptBlockSeAtten( rcut, @@ -293,6 +301,7 @@ def __init__( self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias self.type_map = type_map + self.compress = False self.type_embedding = TypeEmbedNet( ntypes, tebd_dim, @@ -551,6 +560,84 @@ def t_cvt(xx): ) return obj + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the model compression process + if self.compress: + raise ValueError("Compression is already enabled.") + assert ( + not self.se_atten.resnet_dt + ), "Model compression error: descriptor resnet_dt must be false!" + for tt in self.se_atten.exclude_types: + if (tt[0] not in range(self.se_atten.ntypes)) or ( + tt[1] not in range(self.se_atten.ntypes) + ): + raise RuntimeError( + "exclude types " + + str(tt) + + " must be within the number of atomic types " + + str(self.se_atten.ntypes) + + "!" + ) + if ( + self.se_atten.ntypes * self.se_atten.ntypes + - len(self.se_atten.exclude_types) + == 0 + ): + raise RuntimeError( + "Empty embedding-nets are not supported in model compression!"
+ ) + + if self.se_atten.attn_layer != 0: + raise RuntimeError("Cannot compress model when attention layer is not 0.") + + if self.tebd_input_mode != "strip": + raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'") + + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + + self.se_atten.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 56cb1f5bc6..630b96ce9b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -58,11 +58,34 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from .base_descriptor import ( BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): + + def tabulate_fusion_se_a( + argument0, + argument1, + argument2, + argument3, + argument4, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a + @BaseDescriptor.register("se_e2_a") @BaseDescriptor.register("se_a") @@ -93,6 +116,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.compress = False self.sea = DescrptBlockSeA( rcut, rcut_smth, @@ -225,6 +249,53 @@ def reinit_exclude( """Update the type exclusions.""" self.sea.reinit_exclude(exclude_types) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
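All enable_compression() entry points added in this PR share the calling convention sketched here (the 0.9 Angstrom distance is a made-up placeholder for the value that neighbor statistics of the training data would supply):

# tabulate every embedding net once; forward() then takes the fused
# table-lookup branch instead of evaluating the networks
descriptor.enable_compression(
    min_nbor_dist=0.9,    # nearest interatomic distance in the training data (placeholder)
    table_extrapolate=5,  # second table reaches 5 x the detected upper bound
    table_stride_1=0.01,  # fine stride of the first table
    table_stride_2=0.1,   # coarse stride of the extrapolation table
)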
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.sea.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def forward( self, coord_ext: torch.Tensor, @@ -366,6 +437,10 @@ def update_sel( class DescrptBlockSeA(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] def __init__( self, @@ -425,6 +500,13 @@ def __init__( self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] + ndim = 1 if self.type_one_side else 2 filter_layers = NetworkCollection( ndim=ndim, ntypes=len(sel), network_type="embedding_network" @@ -443,6 +525,7 @@ def __init__( self.filter_layers = filter_layers self.stats = None # set trainable + self.trainable = trainable for param in self.parameters(): param.requires_grad = trainable @@ -470,6 +553,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. 
The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + def get_dim_emb(self) -> int: """Returns the output dimension.""" return self.neuron[-1] @@ -578,6 +665,19 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -627,6 +727,7 @@ def forward( for embedding_idx, ll in enumerate(self.filter_layers.networks): if self.type_one_side: ii = embedding_idx + ti = -1 # torch.jit is not happy with slice(None) # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) # applying a mask seems to cause performance degradation @@ -648,10 +749,35 @@ def forward( rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] rr = rr * mm[:, :, None] ss = rr[:, :, :1] + + if self.compress: + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf + tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) + gr = torch.ops.deepmd.tabulate_fusion_se_a( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ss.contiguous(), + rr.contiguous(), + self.filter_neuron[-1], + )[0] + else: + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 4 x ng + gr = torch.matmul(rr.permute(0, 2, 1), gg) + - # nfnl x nt x ng - gg = ll.forward(ss) - # nfnl x 4 x ng - gr = torch.matmul(rr.permute(0, 2, 1), gg) if ti_mask is not None: xyz_scatter[ti_mask] += gr else: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index aab72f7e98..8c56ccf827 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -49,9 +49,33 @@ check_version_compatibility, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"): + + def tabulate_fusion_se_atten( + argument0, + argument1, + argument2, + argument3, + argument4, + argument5, + argument6, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_atten is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS.
+ torch.ops.deepmd.tabulate_fusion_se_atten = tabulate_fusion_se_atten + @DescriptorBlock.register("se_atten") class DescrptBlockSeAtten(DescriptorBlock): + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] + def __init__( self, rcut: float, @@ -178,6 +202,14 @@ def __init__( ln_eps = 1e-5 self.ln_eps = ln_eps + # add for compression + self.compress = False + self.is_sorted = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] + if isinstance(sel, int): sel = [sel] @@ -189,6 +221,7 @@ def __init__( self.ndescrpt = self.nnei * 4 # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) + self.dpa1_attention = NeighborGatedAttention( self.attn_layer, self.nnei, @@ -277,6 +310,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + def get_dim_emb(self) -> int: """Returns the output dimension of embedding.""" return self.filter_neuron[-1] @@ -384,8 +421,22 @@ def reinit_exclude( exclude_types: list[tuple[int, int]] = [], ): self.exclude_types = exclude_types + self.is_sorted = len(self.exclude_types) == 0 self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -450,20 +501,21 @@ def forward( sw = torch.squeeze(sw, -1) # nf x nloc x nt -> nf x nloc x nnei x nt atype_tebd = extended_atype_embd[:, :nloc, :] - atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1) + atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1) # i # nf x nall x nt nt = extended_atype_embd.shape[-1] atype_tebd_ext = extended_atype_embd # nb x (nloc x nnei) x nt index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt) # nb x (nloc x nnei) x nt - atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) + atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # j # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0) # (nb x nloc) x nnei exclude_mask = exclude_mask.view(nb * nloc, nnei) + # nfnl x nnei x 4 dmatrix = dmatrix.view(-1, self.nnei, 4) nfnl = dmatrix.shape[0] @@ -482,33 +534,91 @@ def forward( ss = torch.concat([ss, nlist_tebd], dim=2) # nfnl x nnei x ng gg = self.filter_layers.networks[0](ss) + input_r = torch.nn.functional.normalize( + rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.nnei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) elif self.tebd_input_mode in ["strip"]: - # nfnl x nnei x ng - gg_s = self.filter_layers.networks[0](ss) - assert self.filter_layers_strip is not None - if not self.type_one_side: - # nfnl x nnei x (tebd_dim * 2) - tt = torch.concat([nlist_tebd, atype_tebd], dim=2) + if self.compress: + net = "filter_net" + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], +
self.table_config[2], + self.table_config[3], + ] + ss = ss.reshape(-1, 1) + # nfnl x nnei x ng + # gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) # dynamic, index + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + # gg = gg_s * gg_t + gg_s + tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec) + info_tensor = torch.tensor(info, dtype=self.prec, device="cpu") + gg_t = gg_t.reshape(-1, gg_t.size(-1)) + # Convert all tensors to the required precision at once + ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) + xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( + tensor_data.contiguous(), + info_tensor.contiguous(), + ss.contiguous(), + rr.contiguous(), + gg_t.contiguous(), + self.filter_neuron[-1], + self.is_sorted, + )[0] + # to make torchscript happy + gg = torch.empty( + nframes, + nloc, + self.nnei, + self.filter_neuron[-1], + dtype=gg_t.dtype, + device=gg_t.device, + ) else: - # nfnl x nnei x tebd_dim - tt = nlist_tebd - # nfnl x nnei x ng - gg_t = self.filter_layers_strip.networks[0](tt) - if self.smooth: - gg_t = gg_t * sw.reshape(-1, self.nnei, 1) - # nfnl x nnei x ng - gg = gg_s * gg_t + gg_s + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) # dynamic, index + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s + input_r = torch.nn.functional.normalize( + rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.nnei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) else: raise NotImplementedError - input_r = torch.nn.functional.normalize( - rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - # nfnl x 4 x ng - xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] @@ -516,9 +626,12 @@ def forward( result = torch.matmul( xyz_scatter_1, xyz_scatter_2 ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] + return ( result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]), + gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]) + if not self.compress + else None, dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:], rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3), sw, diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 0aa50c613f..4a74b7671f 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -32,9 +32,15 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel
import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -52,10 +58,31 @@ BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"): + + def tabulate_fusion_se_r( + argument0, + argument1, + argument2, + argument3, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_r is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_r = tabulate_fusion_se_r + @BaseDescriptor.register("se_e2_r") @BaseDescriptor.register("se_r") class DescrptSeR(BaseDescriptor, torch.nn.Module): + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] + def __init__( self, rcut, @@ -90,6 +117,12 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] self.sel = sel self.sec = torch.tensor( @@ -123,6 +156,7 @@ def __init__( self.filter_layers = filter_layers self.stats = None # set trainable + self.trainable = trainable for param in self.parameters(): param.requires_grad = trainable @@ -313,6 +347,51 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.table_data = self.table.data + self.compress = True + def forward( self, coord_ext: torch.Tensor, @@ -353,7 +432,7 @@ def forward( The smooth switch function.
""" - del mapping + del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] atype = atype_ext[:, :nloc] @@ -380,19 +459,44 @@ def forward( # nfnl x nnei exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei) + xyz_scatter_total = [] for ii, ll in enumerate(self.filter_layers.networks): # nfnl x nt mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 1 ss = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] ss = ss * mm[:, :, None] - # nfnl x nt x ng - gg = ll.forward(ss) - gg = torch.mean(gg, dim=1).unsqueeze(1) - xyz_scatter += gg * (self.sel[ii] / self.nnei) + if self.compress: + ss = ss.squeeze(-1) + net = "filter_-1_net_" + str(ii) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) + xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_r( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ss, + self.filter_neuron[-1], + )[0] + xyz_scatter_total.append(xyz_scatter) + else: + # nfnl x nt x ng + gg = ll.forward(ss) + gg = torch.mean(gg, dim=1).unsqueeze(1) + xyz_scatter += gg * (self.sel[ii] / self.nnei) res_rescale = 1.0 / 5.0 - result = xyz_scatter * res_rescale + if self.compress: + xyz_scatter = torch.cat(xyz_scatter_total, dim=1) + result = torch.mean(xyz_scatter, dim=1) * res_rescale + else: + result = xyz_scatter * res_rescale result = result.view(nf, nloc, self.filter_neuron[-1]) return ( result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 7b83bcbd69..5a634d7549 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -58,11 +58,34 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from .base_descriptor import ( BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"): + + def tabulate_fusion_se_t( + argument0, + argument1, + argument2, + argument3, + argument4, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be runned using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_t = tabulate_fusion_se_t + @BaseDescriptor.register("se_e3") @BaseDescriptor.register("se_at") @@ -129,6 +152,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.compress = False self.seat = DescrptBlockSeT( rcut, rcut_smth, @@ -252,6 +276,54 @@ def compute_input_stats( """ return self.seat.compute_input_stats(merged, path) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. 
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + exclude_types=data["exclude_types"], + activation_fn=ActivationFn(data["activation_function"]), + ) + stride_1_scaled = table_stride_1 * 10 + stride_2_scaled = table_stride_2 * 10 + self.table_config = [ + table_extrapolate, + stride_1_scaled, + stride_2_scaled, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, stride_1_scaled, stride_2_scaled + ) + self.seat.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], @@ -396,6 +468,10 @@ def update_sel( class DescrptBlockSeT(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] def __init__( self, @@ -467,6 +543,12 @@ def __init__( self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) @@ -628,6 +710,19 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -711,12 +806,36 @@ def forward( rr_j = rr_j * mm_j[:, :, None] # nfnl x nt_i x nt_j env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) - # nfnl x nt_i x nt_j x 1 - env_ij_reshape = env_ij.unsqueeze(-1) - # nfnl x nt_i x nt_j x ng - gg = ll.forward(env_ij_reshape) - # nfnl x nt_i x nt_j x ng - res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg) + if self.compress: + ebd_env_ij = env_ij.view(-1, 1) + net = "filter_" + str(ti) + "_net_" + str(tj) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + tensor_data = ( + self.table_data[net].to(env_ij.device).to(dtype=self.prec) + ) + ebd_env_ij = ebd_env_ij.to(dtype=self.prec) + env_ij = env_ij.to(dtype=self.prec) + res_ij = torch.ops.deepmd.tabulate_fusion_se_t( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ebd_env_ij.contiguous(), + env_ij.contiguous(), + self.filter_neuron[-1], + )[0] + else: + # nfnl x nt_i x nt_j x 1 + env_ij_reshape = env_ij.unsqueeze(-1) + # nfnl x nt_i x nt_j x ng + gg = ll.forward(env_ij_reshape) + # nfnl x nt_i x nt_j x ng + res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg) res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j)) result += res_ij # xyz_scatter /= (self.nnei * self.nnei) diff --git a/deepmd/pt/model/model/dp_linear_model.py 
b/deepmd/pt/model/model/dp_linear_model.py index d19070fc5b..4028d77228 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -30,7 +30,7 @@ @BaseModel.register("linear_ener") class LinearEnergyModel(DPLinearModel_): - model_type = "ener" + model_type = "linear_ener" def __init__( self, diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index e1ef00f5fe..0f05e3e56d 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -30,7 +30,7 @@ @BaseModel.register("zbl") class DPZBLModel(DPZBLModel_): - model_type = "ener" + model_type = "zbl" def __init__( self, diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py new file mode 100644 index 0000000000..7394ac082d --- /dev/null +++ b/deepmd/pt/utils/tabulate.py @@ -0,0 +1,607 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from functools import ( + cached_property, +) + +import numpy as np +import torch + +import deepmd +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.tabulate import ( + BaseTabulate, +) + +log = logging.getLogger(__name__) + +SQRT_2_PI = np.sqrt(2 / np.pi) +GGELU = 0.044715 + + +class DPTabulate(BaseTabulate): + r"""Class for tabulation. + + Compress a model, which includes tabulating the embedding net. + The table is composed of fifth-order polynomial coefficients and is assembled from two sub-tables. The first table takes the stride (parameter) as its uniform stride, while the second table takes 10 * stride as its uniform stride. + The range of the first table is automatically detected by deepmd-kit, while the second table ranges from the first table's upper boundary (upper) to the extrapolate (parameter) * upper. + + Parameters + ---------- + descrpt + Descriptor of the original model + neuron + Number of neurons in each hidden layer of the embedding net :math:`\\mathcal{N}` + type_one_side + Try to build N_types tables. Otherwise, building N_types^2 tables + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + activation_fn + The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn.
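The six coefficients per output channel mentioned above are the value, the first derivative, half the second derivative, and three fitted higher-order terms (the layout assembled by _build_lower in the TF implementation further down). Under that layout assumption, one tabulated interval would be evaluated roughly like this sketch:

def eval_spline_interval(a0, a1, a2, a3, a4, a5, t):
    # t = x - x_k is the offset into the interval; Horner form of
    # y = a0 + a1*t + a2*t**2 + a3*t**3 + a4*t**4 + a5*t**5
    return a0 + t * (a1 + t * (a2 + t * (a3 + t * (a4 + t * a5))))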
+ """ + + def __init__( + self, + descrpt, + neuron: list[int], + type_one_side: bool = False, + exclude_types: list[list[int]] = [], + activation_fn: ActivationFn = ActivationFn("tanh"), + ) -> None: + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + True, + ) + self.descrpt_type = self._get_descrpt_type() + + supported_descrpt_type = ( + "Atten", + "A", + "T", + "R", + ) + + if self.descrpt_type in supported_descrpt_type: + self.sel_a = self.descrpt.get_sel() + self.rcut = self.descrpt.get_rcut() + self.rcut_smth = self.descrpt.get_rcut_smth() + else: + raise RuntimeError("Unsupported descriptor") + + # functype + activation_map = { + "tanh": 1, + "gelu": 2, + "gelu_tf": 2, + "relu": 3, + "relu6": 4, + "softplus": 5, + "sigmoid": 6, + } + + activation = activation_fn.activation + if activation in activation_map: + self.functype = activation_map[activation] + else: + raise RuntimeError("Unknown activation function type!") + + self.activation_fn = activation_fn + self.davg = self.descrpt.serialize()["@variables"]["davg"] + self.dstd = self.descrpt.serialize()["@variables"]["dstd"] + self.ntypes = self.descrpt.get_ntypes() + + self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"] + + self.layer_size = self._get_layer_size() + self.table_size = self._get_table_size() + + self.bias = self._get_bias() + self.matrix = self._get_matrix() + + self.data_type = self._get_data_type() + self.last_layer_size = self._get_last_layer_size() + + def _make_data(self, xx, idx): + """Generate tabulation data for the given input. + + Parameters + ---------- + xx : np.ndarray + Input values to tabulate + idx : int + Index for accessing the correct network parameters + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + Values, first derivatives, and second derivatives + """ + xx = torch.from_numpy(xx).view(-1, 1).to(env.DEVICE) + for layer in range(self.layer_size): + if layer == 0: + xbar = torch.matmul( + xx, + torch.from_numpy(self.matrix["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ), + ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ) + if self.neuron[0] == 1: + yy = ( + self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + xx + ) + dy = unaggregated_dy_dx_s( + yy - xx, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + torch.ones((1, 1), dtype=yy.dtype) # pylint: disable=no-explicit-device + dy2 = unaggregated_dy2_dx_s( + yy - xx, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + elif self.neuron[0] == 2: + tt, yy = self._layer_1( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy - tt, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + torch.ones((1, 2), dtype=yy.dtype) # pylint: disable=no-explicit-device + dy2 = unaggregated_dy2_dx_s( + yy - tt, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + yy = self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + dy2 = unaggregated_dy2_dx_s( + yy, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + ybar = torch.matmul( + yy, + torch.from_numpy(self.matrix["layer_" + 
str(layer + 1)][idx]).to( + env.DEVICE + ), + ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ) + if self.neuron[layer] == self.neuron[layer - 1]: + zz = ( + self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + yy + ) + dz = unaggregated_dy_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + elif self.neuron[layer] == 2 * self.neuron[layer - 1]: + tt, zz = self._layer_1( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + else: + zz = self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + dy = dz + yy = zz + + vv = zz.detach().cpu().numpy().astype(self.data_type) + dd = dy.detach().cpu().numpy().astype(self.data_type) + d2 = dy2.detach().cpu().numpy().astype(self.data_type) + return vv, dd, d2 + + def _layer_0(self, x, w, b): + w = torch.from_numpy(w).to(env.DEVICE) + b = torch.from_numpy(b).to(env.DEVICE) + return self.activation_fn(torch.matmul(x, w) + b) + + def _layer_1(self, x, w, b): + w = torch.from_numpy(w).to(env.DEVICE) + b = torch.from_numpy(b).to(env.DEVICE) + t = torch.cat([x, x], dim=1) + return t, self.activation_fn(torch.matmul(x, w) + b) + t + + def _get_descrpt_type(self): + if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1): + return "Atten" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA): + return "A" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeR): + return "R" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeT): + return "T" + raise RuntimeError(f"Unsupported descriptor {self.descrpt}") + + def _get_layer_size(self): + # get the number of layers in EmbeddingNet + layer_size = 0 + basic_size = 0 + if self.type_one_side: + basic_size = len(self.embedding_net_nodes) * len(self.neuron) + else: + basic_size = ( + len(self.embedding_net_nodes) + * len(self.embedding_net_nodes[0]) + * len(self.neuron) + ) + if self.descrpt_type == "Atten": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + elif self.descrpt_type == "A": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + elif self.descrpt_type == "T": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + # layer_size = basic_size // int(comb(self.ntypes + 1, 2)) + elif self.descrpt_type == "R": + layer_size = basic_size // ( + self.ntypes * self.ntypes - len(self.exclude_types) + ) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + else: + raise RuntimeError("Unsupported descriptor") + return layer_size + + def _get_network_variable(self, var_name: str) -> dict: + """Get network variables (weights or biases) for all layers. 
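The body below walks the serialized NetworkCollection with a transposed flat index; a small sketch of that arithmetic (the ntypes value is illustrative):

ntypes = 3
for ii in range(ntypes * ntypes):
    ti, tj = ii // ntypes, ii % ntypes  # pair order used by the table
    # where that pair's network sits in the serialized collection:
    serialized_idx = (ii % ntypes) * ntypes + ii // ntypes  # == tj * ntypes + ti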
+ + Parameters + ---------- + var_name : str + Name of the variable to get ('w' for weights, 'b' for biases) + + Returns + ------- + dict + Dictionary mapping layer names to their variables + """ + result = {} + for layer in range(1, self.layer_size + 1): + result["layer_" + str(layer)] = [] + if self.descrpt_type == "Atten": + node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ + var_name + ] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "A": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + elif self.descrpt_type == "T": + for ii in range(self.ntypes): + for jj in range(ii, self.ntypes): + node = self.embedding_net_nodes[jj * self.ntypes + ii][ + "layers" + ][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "R": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + raise RuntimeError("Unsupported descriptor") + return result + + def _get_bias(self): + return self._get_network_variable("b") + + def _get_matrix(self): + return self._get_network_variable("w") + + def _convert_numpy_to_tensor(self): + """Convert self.data from np.ndarray to torch.Tensor.""" + for ii in self.data: + self.data[ii] = torch.tensor(self.data[ii], device=env.DEVICE) # pylint: disable=no-explicit-dtype + + @cached_property + def _n_all_excluded(self) -> int: + """The number of types for which all pairs are excluded.""" + return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) + + +# customized op +def grad(xbar, y, functype): # functype=tanh, gelu, ..
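+    # Derivatives are expressed through the already-computed activation output y
+    # where possible (e.g. d tanh(x)/dx = 1 - y**2 with y = tanh(xbar)); functype
+    # follows activation_map in DPTabulate.__init__ above:
+    # 1=tanh, 2=gelu/gelu_tf, 3=relu, 4=relu6, 5=softplus, 6=sigmoid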
+ if functype == 1: + return 1 - y * y + elif functype == 2: + var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + return ( + 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) + + 0.5 * var + + 0.5 + ) + elif functype == 3: + return 0.0 if xbar <= 0 else 1.0 + elif functype == 4: + return 0.0 if xbar <= 0 or xbar >= 6 else 1.0 + elif functype == 5: + return 1.0 - 1.0 / (1.0 + np.exp(xbar)) + elif functype == 6: + return y * (1 - y) + + raise ValueError(f"Unsupported function type: {functype}") + + +def grad_grad(xbar, y, functype): + if functype == 1: + return -2 * y * (1 - y * y) + elif functype == 2: + var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) + return ( + 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) + - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1 + + var2 + ) + elif functype in [3, 4]: + return 0 + elif functype == 5: + return np.exp(xbar) / ((1 + np.exp(xbar)) * (1 + np.exp(xbar))) + elif functype == 6: + return y * (1 - y) * (1 - 2 * y) + else: + return -1 + + +def unaggregated_dy_dx_s( + y: torch.Tensor, w_np: np.ndarray, xbar: torch.Tensor, functype: int +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if y.dim() != 2: + raise ValueError("Dim of input y should be 2") + if w.dim() != 2: + raise ValueError("Dim of input w should be 2") + if xbar.dim() != 2: + raise ValueError("Dim of input xbar should be 2") + + length, width = y.shape + dy_dx = torch.zeros_like(y) + w = torch.flatten(w) + + for ii in range(length): + for jj in range(width): + dy_dx[ii, jj] = grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] + + return dy_dx + + +def unaggregated_dy2_dx_s( + y: torch.Tensor, + dy: torch.Tensor, + w_np: np.ndarray, + xbar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if y.dim() != 2: + raise ValueError("Dim of input y should be 2") + if dy.dim() != 2: + raise ValueError("Dim of input dy should be 2") + if w.dim() != 2: + raise ValueError("Dim of input w should be 2") + if xbar.dim() != 2: + raise ValueError("Dim of input xbar should be 2") + + length, width = y.shape + dy2_dx = torch.zeros_like(y) + w = torch.flatten(w) + + for ii in range(length): + for jj in range(width): + dy2_dx[ii, jj] = ( + grad_grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] * w[jj] + ) + + return dy2_dx + + +def unaggregated_dy_dx( + z: torch.Tensor, + w_np: np.ndarray, + dy_dx: torch.Tensor, + ybar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if z.dim() != 2: + raise ValueError("z tensor must have 2 dimensions") + if w.dim() != 2: + raise ValueError("w tensor must have 2 dimensions") + if dy_dx.dim() != 2: + raise ValueError("dy_dx tensor must have 2 dimensions") + if ybar.dim() != 2: + raise ValueError("ybar tensor must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + dy_dx = torch.flatten(dy_dx) + + dz_dx = torch.zeros_like(z) + + for kk in range(length): + for ii in range(width): + dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) + accumulator = 0.0 + for jj in range(size): + accumulator += w[jj, ii] * dy_dx[kk * size + jj] + dz_drou *= accumulator + if width == 2 * size or width == size: + dz_drou += dy_dx[kk * size + ii % size] + dz_dx[kk, ii] = dz_drou + + return dz_dx + + +def unaggregated_dy2_dx( + z: torch.Tensor, + w_np: np.ndarray, + dy_dx: torch.Tensor, + dy2_dx: torch.Tensor, + ybar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if z.dim() != 2: + raise 
ValueError("z tensor must have 2 dimensions") + if w.dim() != 2: + raise ValueError("w tensor must have 2 dimensions") + if dy_dx.dim() != 2: + raise ValueError("dy_dx tensor must have 2 dimensions") + if dy2_dx.dim() != 2: + raise ValueError("dy2_dx tensor must have 2 dimensions") + if ybar.dim() != 2: + raise ValueError("ybar tensor must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + dy_dx = torch.flatten(dy_dx) + dy2_dx = torch.flatten(dy2_dx) + + dz2_dx = torch.zeros_like(z) + + for kk in range(length): + for ii in range(width): + dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) + accumulator1 = 0.0 + for jj in range(size): + accumulator1 += w[jj, ii] * dy2_dx[kk * size + jj] + dz_drou *= accumulator1 + accumulator2 = 0.0 + for jj in range(size): + accumulator2 += w[jj, ii] * dy_dx[kk * size + jj] + dz_drou += ( + grad_grad(ybar[kk, ii], z[kk, ii], functype) + * accumulator2 + * accumulator2 + ) + if width == 2 * size or width == size: + dz_drou += dy2_dx[kk * size + ii % size] + dz2_dx[kk, ii] = dz_drou + + return dz2_dx diff --git a/deepmd/tf/utils/tabulate.py b/deepmd/tf/utils/tabulate.py index 588ebdd55e..30171b12db 100644 --- a/deepmd/tf/utils/tabulate.py +++ b/deepmd/tf/utils/tabulate.py @@ -2,7 +2,6 @@ import logging from functools import ( cached_property, - lru_cache, ) from typing import ( Callable, @@ -28,11 +27,14 @@ get_embedding_net_nodes_from_graph_def, get_tensor_by_name_from_graph, ) +from deepmd.utils.tabulate import ( + BaseTabulate, +) log = logging.getLogger(__name__) -class DPTabulate: +class DPTabulate(BaseTabulate): r"""Class for tabulation. Compress a model, which including tabulating the embedding-net. @@ -71,13 +73,18 @@ def __init__( activation_fn: Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh, suffix: str = "", ) -> None: + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + False, + ) + + self.descrpt_type = self._get_descrpt_type() """Constructor.""" - self.descrpt = descrpt - self.neuron = neuron self.graph = graph self.graph_def = graph_def - self.type_one_side = type_one_side - self.exclude_types = exclude_types self.suffix = suffix # functype @@ -156,271 +163,25 @@ def __init__( self.upper = {} self.lower = {} - def build( - self, min_nbor_dist: float, extrapolate: float, stride0: float, stride1: float - ) -> tuple[dict[str, int], dict[str, int]]: - r"""Build the tables for model compression. 
- - Parameters - ---------- - min_nbor_dist - The nearest distance between neighbor atoms - extrapolate - The scale of model extrapolation - stride0 - The uniform stride of the first table - stride1 - The uniform stride of the second table - - Returns - ------- - lower : dict[str, int] - The lower boundary of environment matrix by net - upper : dict[str, int] - The upper boundary of environment matrix by net - """ - # tabulate range [lower, upper] with stride0 'stride0' - lower, upper = self._get_env_mat_range(min_nbor_dist) - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( - self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2 - ): - uu = np.max(upper) - ll = np.min(lower) - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type)) - nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype( - int - ) - self._build_lower( - "filter_net", xx, 0, uu, ll, stride0, stride1, extrapolate, nspline - ) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - for ii in range(self.table_size): - if (self.type_one_side and not self._all_excluded(ii)) or ( - not self.type_one_side - and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types - ): - if self.type_one_side: - net = "filter_-1_net_" + str(ii) - # upper and lower should consider all types which are not excluded and sel>0 - idx = [ - (type_i, ii) not in self.exclude_types - and self.sel_a[type_i] > 0 - for type_i in range(self.ntypes) - ] - uu = np.max(upper[idx]) - ll = np.min(lower[idx]) - else: - ielement = ii // self.ntypes - net = ( - "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) - ) - uu = upper[ielement] - ll = lower[ielement] - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append( - xx, np.array([extrapolate * uu], dtype=self.data_type) - ) - nspline = ( - (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 - ).astype(int) - self._build_lower( - net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline - ) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - xx_all = [] - for ii in range(self.ntypes): - xx = np.arange( - extrapolate * lower[ii], lower[ii], stride1, dtype=self.data_type - ) - xx = np.append( - xx, np.arange(lower[ii], upper[ii], stride0, dtype=self.data_type) - ) - xx = np.append( - xx, - np.arange( - upper[ii], - extrapolate * upper[ii], - stride1, - dtype=self.data_type, - ), - ) - xx = np.append( - xx, np.array([extrapolate * upper[ii]], dtype=self.data_type) - ) - xx_all.append(xx) - nspline = ( - (upper - lower) / stride0 - + 2 * ((extrapolate * upper - upper) / stride1) - ).astype(int) - idx = 0 - for ii in range(self.ntypes): - for jj in range(ii, self.ntypes): - net = "filter_" + str(ii) + "_net_" + str(jj) - self._build_lower( - net, - xx_all[ii], - idx, - upper[ii], - lower[ii], - stride0, - stride1, - extrapolate, - nspline[ii], - ) - idx += 1 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - for ii in range(self.table_size): - if (self.type_one_side and not self._all_excluded(ii)) or ( - not self.type_one_side - and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types - ): - if self.type_one_side: - net = "filter_-1_net_" + str(ii) - # upper and lower should consider all types which are not excluded and 
sel>0 - idx = [ - (type_i, ii) not in self.exclude_types - and self.sel_a[type_i] > 0 - for type_i in range(self.ntypes) - ] - uu = np.max(upper[idx]) - ll = np.min(lower[idx]) - else: - ielement = ii // self.ntypes - net = ( - "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) - ) - uu = upper[ielement] - ll = lower[ielement] - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append( - xx, np.array([extrapolate * uu], dtype=self.data_type) - ) - nspline = ( - (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 - ).astype(int) - self._build_lower( - net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline - ) - else: - raise RuntimeError("Unsupported descriptor") - self._convert_numpy_to_tensor() - - return self.lower, self.upper - - def _build_lower( - self, net, xx, idx, upper, lower, stride0, stride1, extrapolate, nspline - ): - vv, dd, d2 = self._make_data(xx, idx) - self.data[net] = np.zeros( - [nspline, 6 * self.last_layer_size], dtype=self.data_type - ) - - # tt.shape: [nspline, self.last_layer_size] - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[: int((upper - lower) / stride0), :] = stride0 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[ - int((lower - extrapolate * lower) / stride1) + 1 : ( - int((lower - extrapolate * lower) / stride1) - + int((upper - lower) / stride0) - ), - :, - ] = stride0 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[: int((upper - lower) / stride0), :] = stride0 - else: - raise RuntimeError("Unsupported descriptor") - - # hh.shape: [nspline, self.last_layer_size] - hh = ( - vv[1 : nspline + 1, : self.last_layer_size] - - vv[:nspline, : self.last_layer_size] - ) - - self.data[net][:, : 6 * self.last_layer_size : 6] = vv[ - :nspline, : self.last_layer_size - ] - self.data[net][:, 1 : 6 * self.last_layer_size : 6] = dd[ - :nspline, : self.last_layer_size - ] - self.data[net][:, 2 : 6 * self.last_layer_size : 6] = ( - 0.5 * d2[:nspline, : self.last_layer_size] - ) - self.data[net][:, 3 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt) - ) * ( - 20 * hh - - ( - 8 * dd[1 : nspline + 1, : self.last_layer_size] - + 12 * dd[:nspline, : self.last_layer_size] - ) - * tt - - ( - 3 * d2[:nspline, : self.last_layer_size] - - d2[1 : nspline + 1, : self.last_layer_size] - ) - * tt - * tt - ) - self.data[net][:, 4 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt * tt) - ) * ( - -30 * hh - + ( - 14 * dd[1 : nspline + 1, : self.last_layer_size] - + 16 * dd[:nspline, : self.last_layer_size] - ) - * tt - + ( - 3 * d2[:nspline, : self.last_layer_size] - - 2 * d2[1 : nspline + 1, : self.last_layer_size] - ) - * tt - * tt - ) - self.data[net][:, 5 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt * tt * tt) - ) * ( - 12 * hh - - 6 - * ( - dd[1 : nspline + 1, : self.last_layer_size] - + dd[:nspline, : self.last_layer_size] - ) - * tt - + ( - d2[1 : nspline + 1, : self.last_layer_size] - - d2[:nspline, : self.last_layer_size] - ) - * tt - * tt - ) - - self.upper[net] = upper - self.lower[net] = lower - def _load_sub_graph(self): sub_graph_def = tf.GraphDef() with 
tf.Graph().as_default() as sub_graph: tf.import_graph_def(sub_graph_def, name="") return sub_graph, sub_graph_def + def _get_descrpt_type(self): + if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten): + return "Atten" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2): + return "AEbdV2" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): + return "A" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): + return "T" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): + return "R" + raise RuntimeError(f"Unsupported descriptor {self.descrpt}") + def _get_bias(self): bias = {} for layer in range(1, self.layer_size + 1): @@ -711,36 +472,6 @@ def _layer_1(self, x, w, b): t = tf.concat([x, x], axis=1) return t, self.activation_fn(tf.matmul(x, w) + b) + t - # Change the embedding net range to sw / min_nbor_dist - def _get_env_mat_range(self, min_nbor_dist): - sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut) - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - lower = -self.davg[:, 0] / self.dstd[:, 0] - upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4])) - lower = np.min(-var, axis=1) - upper = np.max(var, axis=1) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - lower = -self.davg[:, 0] / self.dstd[:, 0] - upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] - else: - raise RuntimeError("Unsupported descriptor") - log.info("training data with lower boundary: " + str(lower)) - log.info("training data with upper boundary: " + str(upper)) - # returns element-wise lower and upper - return np.floor(lower), np.ceil(upper) - - def _spline5_switch(self, xx, rmin, rmax): - if xx < rmin: - vv = 1 - elif xx < rmax: - uu = (xx - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 - else: - vv = 0 - return vv - def _get_layer_size(self): layer_size = 0 if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( @@ -776,54 +507,6 @@ def _n_all_excluded(self) -> int: """Then number of types excluding all types.""" return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) - @lru_cache - def _all_excluded(self, ii: int) -> bool: - """Check if type ii excluds all types. 
- - Parameters - ---------- - ii : int - type index - - Returns - ------- - bool - if type ii excluds all types - """ - return all((ii, type_i) in self.exclude_types for type_i in range(self.ntypes)) - - def _get_table_size(self): - table_size = 0 - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( - self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2 - ): - table_size = 1 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - table_size = self.ntypes * self.ntypes - if self.type_one_side: - table_size = self.ntypes - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - table_size = int(comb(self.ntypes + 1, 2)) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - table_size = self.ntypes * self.ntypes - if self.type_one_side: - table_size = self.ntypes - else: - raise RuntimeError("Unsupported descriptor") - return table_size - - def _get_data_type(self): - for item in self.matrix["layer_" + str(self.layer_size)]: - if len(item) != 0: - return type(item[0][0]) - return None - - def _get_last_layer_size(self): - for item in self.matrix["layer_" + str(self.layer_size)]: - if len(item) != 0: - return item.shape[1] - return 0 - def _convert_numpy_to_tensor(self): """Convert self.data from np.ndarray to tf.Tensor.""" for ii in self.data: diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py new file mode 100644 index 0000000000..545b265b88 --- /dev/null +++ b/deepmd/utils/tabulate.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from abc import ( + ABC, + abstractmethod, +) +from functools import ( + lru_cache, +) + +import numpy as np +from scipy.special import ( + comb, +) + +log = logging.getLogger(__name__) + + +class BaseTabulate(ABC): + """A base class for pt and tf tabulation.""" + + def __init__( + self, + descrpt, + neuron, + type_one_side, + exclude_types, + is_pt, + ) -> None: + """Constructor.""" + super().__init__() + + """Shared attributes.""" + self.descrpt = descrpt + self.neuron = neuron + self.type_one_side = type_one_side + self.exclude_types = exclude_types + self.is_pt = is_pt + + """Need to be initialized in the subclass.""" + self.descrpt_type = "Base" + + self.sel_a = [] + self.rcut = 0.0 + self.rcut_smth = 0.0 + + self.davg = np.array([]) + self.dstd = np.array([]) + self.ntypes = 0 + + self.layer_size = 0 + self.table_size = 0 + + self.bias = {} + self.matrix = {} + + self.data_type = None + self.last_layer_size = 0 + + """Save the tabulation result.""" + self.data = {} + + self.upper = {} + self.lower = {} + + def build( + self, min_nbor_dist: float, extrapolate: float, stride0: float, stride1: float + ) -> tuple[dict[str, int], dict[str, int]]: + r"""Build the tables for model compression. 
+
+        Parameters
+        ----------
+        min_nbor_dist
+            The nearest distance between neighbor atoms
+        extrapolate
+            The scale of model extrapolation
+        stride0
+            The uniform stride of the first table
+        stride1
+            The uniform stride of the second table
+
+        Returns
+        -------
+        lower : dict[str, int]
+            The lower boundary of environment matrix by net
+        upper : dict[str, int]
+            The upper boundary of environment matrix by net
+        """
+        # tabulate the range [lower, upper] with the uniform stride 'stride0'
+        lower, upper = self._get_env_mat_range(min_nbor_dist)
+        if self.descrpt_type in ("Atten", "AEbdV2"):
+            uu = np.max(upper)
+            ll = np.min(lower)
+            xx = np.arange(ll, uu, stride0, dtype=self.data_type)
+            xx = np.append(
+                xx,
+                np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type),
+            )
+            xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type))
+            nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype(
+                int
+            )
+            self._build_lower(
+                "filter_net", xx, 0, uu, ll, stride0, stride1, extrapolate, nspline
+            )
+        elif self.descrpt_type == "A":
+            for ii in range(self.table_size):
+                if (self.type_one_side and not self._all_excluded(ii)) or (
+                    not self.type_one_side
+                    and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types
+                ):
+                    if self.type_one_side:
+                        net = "filter_-1_net_" + str(ii)
+                        # upper and lower should consider all types which are not excluded and sel>0
+                        idx = [
+                            (type_i, ii) not in self.exclude_types
+                            and self.sel_a[type_i] > 0
+                            for type_i in range(self.ntypes)
+                        ]
+                        uu = np.max(upper[idx])
+                        ll = np.min(lower[idx])
+                    else:
+                        ielement = ii // self.ntypes
+                        net = (
+                            "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes)
+                        )
+                        if self.is_pt:
+                            uu = np.max(upper[ielement])
+                            ll = np.min(lower[ielement])
+                        else:
+                            uu = upper[ielement]
+                            ll = lower[ielement]
+                    xx = np.arange(ll, uu, stride0, dtype=self.data_type)
+                    xx = np.append(
+                        xx,
+                        np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type),
+                    )
+                    xx = np.append(
+                        xx, np.array([extrapolate * uu], dtype=self.data_type)
+                    )
+                    nspline = (
+                        (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1
+                    ).astype(int)
+                    self._build_lower(
+                        net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline
+                    )
+        elif self.descrpt_type == "T":
+            xx_all = []
+            for ii in range(self.ntypes):
+                """Pt and tf differ here.
Pt version is a two-dimensional array.""" + if self.is_pt: + uu = np.max(upper[ii]) + ll = np.min(lower[ii]) + else: + ll = lower[ii] + uu = upper[ii] + xx = np.arange(extrapolate * ll, ll, stride1, dtype=self.data_type) + xx = np.append(xx, np.arange(ll, uu, stride0, dtype=self.data_type)) + xx = np.append( + xx, + np.arange( + uu, + extrapolate * uu, + stride1, + dtype=self.data_type, + ), + ) + xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type)) + xx_all.append(xx) + nspline = ( + (upper - lower) / stride0 + + 2 * ((extrapolate * upper - upper) / stride1) + ).astype(int) + idx = 0 + for ii in range(self.ntypes): + if self.is_pt: + uu = np.max(upper[ii]) + ll = np.min(lower[ii]) + else: + ll = lower[ii] + uu = upper[ii] + for jj in range(ii, self.ntypes): + net = "filter_" + str(ii) + "_net_" + str(jj) + self._build_lower( + net, + xx_all[ii], + idx, + uu, + ll, + stride0, + stride1, + extrapolate, + nspline[ii][0] if self.is_pt else nspline[ii], + ) + idx += 1 + elif self.descrpt_type == "R": + for ii in range(self.table_size): + if (self.type_one_side and not self._all_excluded(ii)) or ( + not self.type_one_side + and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types + ): + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + # upper and lower should consider all types which are not excluded and sel>0 + idx = [ + (type_i, ii) not in self.exclude_types + and self.sel_a[type_i] > 0 + for type_i in range(self.ntypes) + ] + uu = np.max(upper[idx]) + ll = np.min(lower[idx]) + else: + ielement = ii // self.ntypes + net = ( + "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) + ) + uu = upper[ielement] + ll = lower[ielement] + xx = np.arange(ll, uu, stride0, dtype=self.data_type) + xx = np.append( + xx, + np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), + ) + xx = np.append( + xx, np.array([extrapolate * uu], dtype=self.data_type) + ) + nspline = ( + (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 + ).astype(int) + self._build_lower( + net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline + ) + else: + raise RuntimeError("Unsupported descriptor") + + self._convert_numpy_to_tensor() + if self.is_pt: + self._convert_numpy_float_to_int() + return self.lower, self.upper + + def _build_lower( + self, net, xx, idx, upper, lower, stride0, stride1, extrapolate, nspline + ): + vv, dd, d2 = self._make_data(xx, idx) + self.data[net] = np.zeros( + [nspline, 6 * self.last_layer_size], dtype=self.data_type + ) + + # tt.shape: [nspline, self.last_layer_size] + if self.descrpt_type in ("Atten", "A", "AEbdV2"): + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[: int((upper - lower) / stride0), :] = stride0 + elif self.descrpt_type == "T": + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[ + int((lower - extrapolate * lower) / stride1) + 1 : ( + int((lower - extrapolate * lower) / stride1) + + int((upper - lower) / stride0) + ), + :, + ] = stride0 + elif self.descrpt_type == "R": + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[: int((upper - lower) / stride0), :] = stride0 + else: + raise RuntimeError("Unsupported descriptor") + + # hh.shape: [nspline, self.last_layer_size] + hh = ( + vv[1 : nspline + 1, : self.last_layer_size] + - vv[:nspline, : self.last_layer_size] + ) + + self.data[net][:, : 6 * self.last_layer_size : 6] = vv[ + :nspline, : self.last_layer_size + ] + self.data[net][:, 1 
: 6 * self.last_layer_size : 6] = dd[
+            :nspline, : self.last_layer_size
+        ]
+        self.data[net][:, 2 : 6 * self.last_layer_size : 6] = (
+            0.5 * d2[:nspline, : self.last_layer_size]
+        )
+        self.data[net][:, 3 : 6 * self.last_layer_size : 6] = (
+            1 / (2 * tt * tt * tt)
+        ) * (
+            20 * hh
+            - (
+                8 * dd[1 : nspline + 1, : self.last_layer_size]
+                + 12 * dd[:nspline, : self.last_layer_size]
+            )
+            * tt
+            - (
+                3 * d2[:nspline, : self.last_layer_size]
+                - d2[1 : nspline + 1, : self.last_layer_size]
+            )
+            * tt
+            * tt
+        )
+        self.data[net][:, 4 : 6 * self.last_layer_size : 6] = (
+            1 / (2 * tt * tt * tt * tt)
+        ) * (
+            -30 * hh
+            + (
+                14 * dd[1 : nspline + 1, : self.last_layer_size]
+                + 16 * dd[:nspline, : self.last_layer_size]
+            )
+            * tt
+            + (
+                3 * d2[:nspline, : self.last_layer_size]
+                - 2 * d2[1 : nspline + 1, : self.last_layer_size]
+            )
+            * tt
+            * tt
+        )
+        self.data[net][:, 5 : 6 * self.last_layer_size : 6] = (
+            1 / (2 * tt * tt * tt * tt * tt)
+        ) * (
+            12 * hh
+            - 6
+            * (
+                dd[1 : nspline + 1, : self.last_layer_size]
+                + dd[:nspline, : self.last_layer_size]
+            )
+            * tt
+            + (
+                d2[1 : nspline + 1, : self.last_layer_size]
+                - d2[:nspline, : self.last_layer_size]
+            )
+            * tt
+            * tt
+        )
+
+        self.upper[net] = upper
+        self.lower[net] = lower
+
+    @abstractmethod
+    def _make_data(self, xx, idx) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Generate tabulation data for the given input.
+
+        Parameters
+        ----------
+        xx : np.ndarray
+            Input values to tabulate
+        idx : int
+            Index for accessing the correct network parameters
+
+        Returns
+        -------
+        tuple[np.ndarray, np.ndarray, np.ndarray]
+            Values, first derivatives, and second derivatives
+        """
+        pass
+
+    @lru_cache
+    def _all_excluded(self, ii: int) -> bool:
+        """Check if type ii excludes all types.
+
+        Parameters
+        ----------
+        ii : int
+            type index
+
+        Returns
+        -------
+        bool
+            whether type ii excludes all types
+        """
+        return all((ii, type_i) in self.exclude_types for type_i in range(self.ntypes))
+
+    @abstractmethod
+    def _get_descrpt_type(self):
+        """Get the descrpt type."""
+        pass
+
+    @abstractmethod
+    def _get_layer_size(self):
+        """Get the number of embedding-net layers."""
+        pass
+
+    def _get_table_size(self):
+        table_size = 0
+        if self.descrpt_type in ("Atten", "AEbdV2"):
+            table_size = 1
+        elif self.descrpt_type == "A":
+            table_size = self.ntypes * self.ntypes
+            if self.type_one_side:
+                table_size = self.ntypes
+        elif self.descrpt_type == "T":
+            table_size = int(comb(self.ntypes + 1, 2))
+        elif self.descrpt_type == "R":
+            table_size = self.ntypes * self.ntypes
+            if self.type_one_side:
+                table_size = self.ntypes
+        else:
+            raise RuntimeError("Unsupported descriptor")
+        return table_size
+
+    def _get_data_type(self):
+        for item in self.matrix["layer_" + str(self.layer_size)]:
+            if len(item) != 0:
+                return type(item[0][0])
+        return None
+
+    def _get_last_layer_size(self):
+        for item in self.matrix["layer_" + str(self.layer_size)]:
+            if len(item) != 0:
+                return item.shape[1]
+        return 0
+
+    @abstractmethod
+    def _get_bias(self):
+        """Get the bias of the embedding net."""
+        pass
+
+    @abstractmethod
+    def _get_matrix(self):
+        """Get the weight matrix of the embedding net."""
+        pass
+
+    @abstractmethod
+    def _convert_numpy_to_tensor(self):
+        """Convert self.data from np.ndarray to the backend tensor type (torch.Tensor or tf.Tensor)."""
+        pass
+
+    def _convert_numpy_float_to_int(self):
+        """Convert self.lower and self.upper from np.float32 or np.float64 to int."""
+        self.lower = {k: int(v) for k, v in self.lower.items()}
+        self.upper = {k: int(v) for k, v in self.upper.items()}
+
+    def _get_env_mat_range(self, min_nbor_dist):
+        """Change the embedding net range to sw / min_nbor_dist."""
+        sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut)
+        if self.descrpt_type in ("Atten", "A", "AEbdV2"):
+            lower = -self.davg[:, 0] / self.dstd[:, 0]
+            upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]
+        elif self.descrpt_type == "T":
+            var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4]))
+            lower = np.min(-var, axis=1)
+            upper = np.max(var, axis=1)
+        elif self.descrpt_type == "R":
+            lower = -self.davg[:, 0] / self.dstd[:, 0]
+            upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]
+        else:
+            raise RuntimeError("Unsupported descriptor")
+        log.info("training data with lower boundary: " + str(lower))
+        log.info("training data with upper boundary: " + str(upper))
+        # returns element-wise lower and upper
+        return np.floor(lower), np.ceil(upper)
+
+    def _spline5_switch(self, xx, rmin, rmax):
+        if xx < rmin:
+            vv = 1
+        elif xx < rmax:
+            uu = (xx - rmin) / (rmax - rmin)
+            vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1
+        else:
+            vv = 0
+        return vv
diff --git a/doc/model/dplr.md b/doc/model/dplr.md
index 91c2251346..cf071d4029 100644
--- a/doc/model/dplr.md
+++ b/doc/model/dplr.md
@@ -198,7 +198,7 @@ fix ID group-ID style_name keyword value ...
 - three or more keyword/value pairs may be appended

 ```
-keyword = *model* or *type_associate* or *bond_type* or *efield*
+keyword = *model* or *type_associate* or *bond_type* or *efield* or *pair_deepmd_index*
   *model* value = name
     name = name of DPLR model file (e.g. frozen_model.pb) (not DW model)
   *type_associate* values = NR1 NW1 NR2 NW2 ...
@@ -208,6 +208,8 @@ keyword = *model* or *type_associate* or *bond_type* or *efield*
     NBi = bond type of i-th (real atom, Wannier centroid) pair
   *efield* (optional) values = Ex Ey Ez
     Ex/Ey/Ez = electric field along x/y/z direction
+  *pair_deepmd_index* (optional) values = idx
+    idx = the index of the pair_style deepmd, starting from 1; required if more than one is used
 ```

 **Examples**
@@ -223,6 +225,8 @@ fix_modify 0 virial yes
 ```

 The fix command `dplr` calculates the position of WCs by the DW model and back-propagates the long-range interaction on virtual atoms to real atoms.
+The fix command must be used after [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd).
+If there is more than one pair_style `deepmd`, `pair_deepmd_index` (starting from 1) must be set to select which pair_style `deepmd` this fix uses.
 The atom names specified in [pair_style `deepmd`](../third-party/lammps-command.md#pair_style-deepmd) will be used to determine elements.
 If it is not set, the training parameter {ref}`type_map <model/type_map>` will be mapped to LAMMPS atom types.
diff --git a/doc/model/train-se-e3.md b/doc/model/train-se-e3.md
index 3d82c42c9e..714d75259a 100644
--- a/doc/model/train-se-e3.md
+++ b/doc/model/train-se-e3.md
@@ -1,7 +1,7 @@
-# Descriptor `"se_e3"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
+# Descriptor `"se_e3"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}

 :::{note}
-**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
+**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
 :::

 The notation of `se_e3` is short for three-body embedding DeepPot-SE, which incorporates embedded bond-angle information.
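For reference, the six numbers that `BaseTabulate._build_lower` stores per spline interval (stride 6 in `self.data[net]`) are the coefficients of a quintic Hermite interpolant built from the values (`vv`), first derivatives (`dd`), and second derivatives (`d2`) at the two interval endpoints. The following standalone numpy sketch is not part of the patch; all names are local, and `np.tanh` merely stands in for the tabulated embedding-net output. It reproduces the same algebra as the `self.data[net][:, k::6]` assignments and checks it against the exact function:

```python
# Minimal check of the quintic Hermite coefficients used by _build_lower.
import numpy as np


def quintic_coeffs(v0, v1, d0, d1, s0, s1, t):
    # c0..c5 of p(x) = sum_k c_k x^k on [0, t]; same formulas as _build_lower
    h = v1 - v0
    return np.array(
        [
            v0,  # value at the left endpoint
            d0,  # first derivative at the left endpoint
            0.5 * s0,  # half the second derivative at the left endpoint
            (20 * h - (8 * d1 + 12 * d0) * t - (3 * s0 - s1) * t * t) / (2 * t**3),
            (-30 * h + (14 * d1 + 16 * d0) * t + (3 * s0 - 2 * s1) * t * t) / (2 * t**4),
            (12 * h - 6 * (d1 + d0) * t + (s1 - s0) * t * t) / (2 * t**5),
        ]
    )


def f(x):
    return np.tanh(x)


def df(x):
    return 1.0 - np.tanh(x) ** 2


def d2f(x):
    return -2.0 * np.tanh(x) * (1.0 - np.tanh(x) ** 2)


x0, t = 0.3, 0.01  # interval start and width (cf. stride0)
c = quintic_coeffs(f(x0), f(x0 + t), df(x0), df(x0 + t), d2f(x0), d2f(x0 + t), t)
xm = 0.4 * t  # query offset inside the interval, relative to x0
approx = sum(ck * xm**k for k, ck in enumerate(c))
np.testing.assert_allclose(approx, f(x0 + xm), atol=1e-10)
```

By construction the interpolant matches value, slope, and curvature at both endpoints, which is what lets the compressed kernels evaluate the embedding net from the table and an interval index alone.
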
diff --git a/source/lmp/fix_dplr.cpp b/source/lmp/fix_dplr.cpp
index 8a6be7d840..34fd2515ed 100644
--- a/source/lmp/fix_dplr.cpp
+++ b/source/lmp/fix_dplr.cpp
@@ -62,6 +62,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg)
   size_vector = 3;
   qe2f = force->qe2f;
   xstyle = ystyle = zstyle = NONE;
+  pair_deepmd_index = 0;

   if (strcmp(update->unit_style, "lj") == 0) {
     error->all(FLERR,
@@ -125,6 +126,12 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg)
       }
       sort(bond_type.begin(), bond_type.end());
       iarg = iend;
+    } else if (string(arg[iarg]) == string("pair_deepmd_index")) {
+      if (iarg + 1 >= narg) {
+        error->all(FLERR, "Illegal pair_deepmd_index, not provided");
+      }
+      pair_deepmd_index = atoi(arg[iarg + 1]);
+      iarg += 2;
     } else {
       break;
     }
@@ -141,7 +148,7 @@ FixDPLR::FixDPLR(LAMMPS *lmp, int narg, char **arg)
     error->one(FLERR, e.what());
   }
-  pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1);
+  pair_deepmd = (PairDeepMD *)force->pair_match("deepmd", 1, pair_deepmd_index);
   if (!pair_deepmd) {
    error->all(FLERR, "pair_style deepmd should be set before this fix\n");
   }
diff --git a/source/lmp/fix_dplr.h b/source/lmp/fix_dplr.h
index a6822fe4fe..c43296e611 100644
--- a/source/lmp/fix_dplr.h
+++ b/source/lmp/fix_dplr.h
@@ -80,6 +80,9 @@ class FixDPLR : public Fix {
   void update_efield_variables();
   enum { NONE, CONSTANT, EQUAL };
   std::vector<int> type_idx_map;
+  /* The index of the deepmd pair_style, which starts from 1. By default 0,
+   * which works only when there is a single deepmd pair_style. */
+  int pair_deepmd_index;
 };
 } // namespace LAMMPS_NS
diff --git a/source/op/pt/tabulate_multi_device.cc b/source/op/pt/tabulate_multi_device.cc
index bdc6f63f94..5c710f5c37 100644
--- a/source/op/pt/tabulate_multi_device.cc
+++ b/source/op/pt/tabulate_multi_device.cc
@@ -905,7 +905,7 @@ class TabulateFusionSeROp
 std::vector<torch::Tensor> tabulate_fusion_se_a(
     const torch::Tensor& table_tensor,
-    const torch::Tensor& table_info_tensor,
+    const torch::Tensor& table_info_tensor,  // only cpu
     const torch::Tensor& em_x_tensor,
     const torch::Tensor& em_tensor,
     int64_t last_layer_size) {
@@ -915,7 +915,7 @@
 std::vector<torch::Tensor> tabulate_fusion_se_atten(
     const torch::Tensor& table_tensor,
-    const torch::Tensor& table_info_tensor,
+    const torch::Tensor& table_info_tensor,  // only cpu
     const torch::Tensor& em_x_tensor,
     const torch::Tensor& em_tensor,
     const torch::Tensor& two_embed_tensor,
@@ -928,7 +928,7 @@ std::vector<torch::Tensor> tabulate_fusion_se_atten(
 std::vector<torch::Tensor> tabulate_fusion_se_t(
     const torch::Tensor& table_tensor,
-    const torch::Tensor& table_info_tensor,
+    const torch::Tensor& table_info_tensor,  // only cpu
     const torch::Tensor& em_x_tensor,
     const torch::Tensor& em_tensor,
     int64_t last_layer_size) {
@@ -938,7 +938,7 @@ std::vector<torch::Tensor> tabulate_fusion_se_t(
 std::vector<torch::Tensor> tabulate_fusion_se_r(
     const torch::Tensor& table_tensor,
-    const torch::Tensor& table_info_tensor,
+    const torch::Tensor& table_info_tensor,  // only cpu
     const torch::Tensor& em_tensor,
     int64_t last_layer_size) {
   return TabulateFusionSeROp::apply(table_tensor, table_info_tensor, em_tensor,
diff --git a/source/tests/array_api_strict/descriptor/se_t.py b/source/tests/array_api_strict/descriptor/se_t.py
new file mode 100644
index 0000000000..13e650aa17
--- /dev/null
+++ b/source/tests/array_api_strict/descriptor/se_t.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Any,
+)
+
+from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
+
+from ..common import (
+    to_array_api_strict_array,
+)
+from 
..utils.exclude_mask import ( + PairExcludeMask, +) +from ..utils.network import ( + NetworkCollection, +) + + +class DescrptSeT(DescrptSeTDP): + def __setattr__(self, name: str, value: Any) -> None: + if name in {"dstd", "davg"}: + value = to_array_api_strict_array(value) + elif name in {"embeddings"}: + if value is not None: + value = NetworkCollection.deserialize(value.serialize()) + elif name == "env_mat": + # env_mat doesn't store any value + pass + elif name == "emask": + value = PairExcludeMask(value.ntypes, value.exclude_types) + + return super().__setattr__(name, value) diff --git a/source/tests/array_api_strict/fitting/fitting.py b/source/tests/array_api_strict/fitting/fitting.py index 5a2bd9c58f..323a49cfe8 100644 --- a/source/tests/array_api_strict/fitting/fitting.py +++ b/source/tests/array_api_strict/fitting/fitting.py @@ -9,6 +9,9 @@ from deepmd.dpmodel.fitting.polarizability_fitting import ( PolarFitting as PolarFittingNetDP, ) +from deepmd.dpmodel.fitting.property_fitting import ( + PropertyFittingNet as PropertyFittingNetDP, +) from ..common import ( to_array_api_strict_array, @@ -43,6 +46,12 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +class PropertyFittingNet(PropertyFittingNetDP): + def __setattr__(self, name: str, value: Any) -> None: + value = setattr_for_general_fitting(name, value) + return super().__setattr__(name, value) + + class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: value = setattr_for_general_fitting(name, value) diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index bcad7c4502..734486becb 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -75,7 +75,7 @@ class CommonTest(ABC): data: ClassVar[dict] """Arguments data.""" - addtional_data: ClassVar[dict] = {} + additional_data: ClassVar[dict] = {} """Additional data that will not be checked.""" tf_class: ClassVar[Optional[type]] """TensorFlow model class.""" @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any: def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) @abstractmethod def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 833b76f6e1..1e6110705a 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -12,6 +12,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -29,6 +31,14 @@ from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF else: DescrptSeTTF = None +if INSTALLED_JAX: + from deepmd.jax.descriptor.se_t import DescrptSeT as DescrptSeTJAX +else: + DescrptSeTJAX = None +if INSTALLED_ARRAY_API_STRICT: + from ...array_api_strict.descriptor.se_t import DescrptSeT as DescrptSeTStrict +else: + DescrptSeTStrict = None from deepmd.utils.argcheck import ( descrpt_se_t_args, ) @@ -91,9 +101,14 @@ def skip_tf(self) -> bool: ) = self.param return env_protection != 0.0 or excluded_types + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_jax = not INSTALLED_JAX + tf_class = DescrptSeTTF dp_class = DescrptSeTDP pt_class = DescrptSeTPT + jax_class = DescrptSeTJAX + array_api_strict_class = DescrptSeTStrict args = descrpt_se_t_args() def setUp(self): @@ 
-168,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_descriptor( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return self.eval_array_api_strict_descriptor( + array_api_strict_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0],) diff --git a/source/tests/consistent/fitting/test_dipole.py b/source/tests/consistent/fitting/test_dipole.py index 55d6c44c34..60ee7322c1 100644 --- a/source/tests/consistent/fitting/test_dipole.py +++ b/source/tests/consistent/fitting/test_dipole.py @@ -104,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_dos.py b/source/tests/consistent/fitting/test_dos.py index 774e3f655e..d3de3ef151 100644 --- a/source/tests/consistent/fitting/test_dos.py +++ b/source/tests/consistent/fitting/test_dos.py @@ -124,7 +124,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index e32410a0ec..f4e78ce966 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -134,7 +134,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_polar.py b/source/tests/consistent/fitting/test_polar.py index 895974baf9..bd9d013b8d 100644 --- a/source/tests/consistent/fitting/test_polar.py +++ b/source/tests/consistent/fitting/test_polar.py @@ -104,7 +104,7 @@ def setUp(self): self.atype.sort() @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/consistent/fitting/test_property.py b/source/tests/consistent/fitting/test_property.py index beb21d9c04..a096d4dd68 100644 --- a/source/tests/consistent/fitting/test_property.py +++ b/source/tests/consistent/fitting/test_property.py @@ -17,6 +17,8 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, INSTALLED_PT, CommonTest, parameterized, @@ -32,6 +34,22 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: PropertyFittingPT = object +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) + from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX +else: + PropertyFittingJAX = object +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict + + from ...array_api_strict.fitting.fitting import ( + PropertyFittingNet as PropertyFittingStrict, + ) +else: + PropertyFittingStrict = object + PropertyFittingTF = object @@ -84,9 +102,14 @@ def skip_pt(self) -> bool: def skip_tf(self) -> bool: return True + skip_jax = not INSTALLED_JAX + skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + tf_class = PropertyFittingTF dp_class = PropertyFittingDP pt_class = PropertyFittingPT + jax_class = PropertyFittingJAX + array_api_strict_class = PropertyFittingStrict args = fitting_property() def setUp(self): @@ -104,7 +127,7 @@ def setUp(self): ).reshape(-1, 1) @property - def addtional_data(self) -> dict: + def 
additional_data(self) -> dict: ( resnet_dt, precision, @@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any: aparam=self.aparam if numb_aparam else None, )["property"] + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if numb_fparam else None, + aparam=jnp.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + array_api_strict.set_array_api_strict_flags(api_version="2023.12") + ( + resnet_dt, + precision, + mixed_types, + numb_fparam, + numb_aparam, + task_dim, + intensive, + ) = self.param + return np.asarray( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None, + aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None, + )["property"] + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: if backend == self.RefBackend.TF: # shape is not same diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 2a358ba7e0..98330ba849 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_pt(data) elif cls is EnergyModelJAX: return get_model_jax(data) - return cls(**data, **self.addtional_data) + return cls(**data, **self.additional_data) def setUp(self): CommonTest.setUp(self) diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py new file mode 100644 index 0000000000..f37bee0c90 --- /dev/null +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.dp_zbl_model import DPZBLModel as DPZBLModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + SKIP_FLAG, + CommonTest, + parameterized, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.dp_zbl_model import DPZBLModel as DPZBLModelPT +else: + DPZBLModelPT = None +import os + +from deepmd.utils.argcheck import ( + model_args, +) + +TESTS_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + +@parameterized( + ( + [], + [[0, 1]], + ), + ( + [], + [1], + ), +) +class TestEner(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + pair_exclude_types, atom_exclude_types = self.param + return { + "type_map": ["O", "H", "B"], + "use_srtab": f"{TESTS_DIR}/pt/water/data/zbl_tab_potential/H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.2, + "sw_rmax": 4.0, + "pair_exclude_types": pair_exclude_types, + "atom_exclude_types": atom_exclude_types, + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.5, + "rcut": 4.0, + "neuron": [3, 6], + "axis_neuron": 2, + "attn": 8, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "activation_function": "tanh", + "scaling_factor": 1.0, + 
"normalize": False, + "temperature": 1.0, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "seed": 1, + }, + } + + dp_class = DPZBLModelDP + pt_class = DPZBLModelPT + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_tf(self): + return True + + @property + def skip_jax(self): + return True + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is DPZBLModelDP: + return get_model_dp(data) + elif cls is DPZBLModelPT: + return get_model_pt(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... 
+ if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.TF: + return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/test_type_embedding.py b/source/tests/consistent/test_type_embedding.py index a4b516ef16..0dd17c841e 100644 --- a/source/tests/consistent/test_type_embedding.py +++ b/source/tests/consistent/test_type_embedding.py @@ -82,7 +82,7 @@ def data(self) -> dict: skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT @property - def addtional_data(self) -> dict: + def additional_data(self) -> dict: ( resnet_dt, precision, diff --git a/source/tests/pt/model/test_compressed_descriptor_se_a.py b/source/tests/pt/model/test_compressed_descriptor_se_a.py new file mode 100644 index 0000000000..14d82a452c --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_a.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorSeA(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_a = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + self.axis_neuron, + type_one_side=self.type_one_side, + seed=21, + 
precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_a, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_a.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_a, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_atten.py b/source/tests/pt/model/test_compressed_descriptor_se_atten.py new file mode 100644 index 0000000000..a439255396 --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_atten.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorSeAtten(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_atten = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.neuron, + self.axis_neuron, + 4, + attn=8, + attn_layer=0, + seed=21, + precision=self.dtype, + type_one_side=self.type_one_side, + tebd_input_mode="strip", + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_atten, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + if self.dtype == "float32": + result_pt = result_pt.to(torch.float32) + elif self.dtype == "float64": + result_pt = result_pt.to(torch.float64) + + self.se_atten.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_atten, + self.natoms, 
+ self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_r.py b/source/tests/pt/model/test_compressed_descriptor_se_r.py new file mode 100644 index 0000000000..156cb9a06d --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_r.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64")) +class TestDescriptorSeR(unittest.TestCase): + def setUp(self): + (self.dtype,) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_r = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + seed=21, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_r, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_r.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_r, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_t.py b/source/tests/pt/model/test_compressed_descriptor_se_t.py new file mode 100644 index 0000000000..aa3054bc0d --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_t.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import 
numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64")) +class TestDescriptorSeT(unittest.TestCase): + def setUp(self): + (self.dtype,) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_t = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + seed=21, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_t, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_t.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_t, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py new file mode 100644 index 0000000000..c03773827d --- /dev/null +++ b/source/tests/pt/test_tabulate.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.tabulate import ( + unaggregated_dy2_dx, + unaggregated_dy2_dx_s, + unaggregated_dy_dx, + unaggregated_dy_dx_s, +) +from deepmd.tf.env import ( + op_module, + tf, +) + + +def setUpModule(): + tf.compat.v1.enable_eager_execution() + + +def tearDownModule(): + tf.compat.v1.disable_eager_execution() + + +class TestDPTabulate(unittest.TestCase): + def setUp(self): + self.w = np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], + dtype=np.float64, + ) + + self.x = np.array( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], + dtype=np.float64, # 4 x 3 + ) + + self.b = np.array([[0.1], [0.2], [0.3], [0.4]], 
dtype=np.float64) # 4 x 1 + + self.xbar = np.matmul(self.x, self.w) + self.b # 4 x 4 + + self.y = np.tanh(self.xbar) + + def test_ops(self): + dy_tf = op_module.unaggregated_dy_dx_s( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy_pt = unaggregated_dy_dx_s( + torch.from_numpy(self.y), + self.w, + torch.from_numpy(self.xbar), + 1, + ) + + dy_tf_numpy = dy_tf.numpy() + dy_pt_numpy = dy_pt.detach().numpy() + + np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10) + + dy2_tf = op_module.unaggregated_dy2_dx_s( + tf.constant(self.y, dtype="double"), + dy_tf, + tf.constant(self.w, dtype="double"), + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy2_pt = unaggregated_dy2_dx_s( + torch.from_numpy(self.y), + dy_pt, + self.w, + torch.from_numpy(self.xbar), + 1, + ) + + dy2_tf_numpy = dy2_tf.numpy() + dy2_pt_numpy = dy2_pt.detach().numpy() + + np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10) + + dz_tf = op_module.unaggregated_dy_dx( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + dy_tf, + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dz_pt = unaggregated_dy_dx( + torch.from_numpy(self.y).to(env.DEVICE), + self.w, + dy_pt, + torch.from_numpy(self.xbar).to(env.DEVICE), + 1, + ) + + dz_tf_numpy = dz_tf.numpy() + dz_pt_numpy = dz_pt.detach().cpu().numpy() + + np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10) + + dy2_tf = op_module.unaggregated_dy2_dx( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + dy_tf, + dy2_tf, + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy2_pt = unaggregated_dy2_dx( + torch.from_numpy(self.y).to(env.DEVICE), + self.w, + dy_pt, + dy2_pt, + torch.from_numpy(self.xbar).to(env.DEVICE), + 1, + ) + + dy2_tf_numpy = dy2_tf.numpy() + dy2_pt_numpy = dy2_pt.detach().cpu().numpy() + + np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10) + + +if __name__ == "__main__": + unittest.main()
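Judging from the derivative formulas, the `functype` codes used by `grad`/`grad_grad` in the new `deepmd/pt/utils/tabulate.py` correspond to tanh (1), the tanh approximation of gelu (2), relu (3), relu6 (4), softplus (5), and sigmoid (6); note that `grad` raises on an unknown code while `grad_grad` returns -1. Beyond the TF-vs-PT comparison in the test above, the first derivatives can also be cross-checked backend-free against a centered finite difference. A minimal sketch, assuming `grad` is importable from the new module exactly as defined in this diff:

```python
# Standalone sanity check (not part of the patch): compare grad() against a
# centered finite difference for activations whose forward form is unambiguous
# from the derivative formulas alone.
import numpy as np

from deepmd.pt.utils.tabulate import grad  # module added by this diff

# forward activations keyed by the functype codes used by grad()/grad_grad()
ACTIVATIONS = {
    1: np.tanh,  # functype 1: tanh, grad = 1 - y**2
    5: lambda x: np.log1p(np.exp(x)),  # functype 5: softplus, grad = sigmoid(x)
    6: lambda x: 1.0 / (1.0 + np.exp(-x)),  # functype 6: sigmoid, grad = y * (1 - y)
}

h = 1e-6
for functype, fn in ACTIVATIONS.items():
    for xbar in (-1.5, -0.2, 0.4, 2.0):
        analytic = grad(xbar, fn(xbar), functype)
        numeric = (fn(xbar + h) - fn(xbar - h)) / (2.0 * h)
        np.testing.assert_allclose(analytic, numeric, rtol=1e-6)
```

The centered difference has O(h**2) error, so the 1e-6 relative tolerance leaves ample margin at these sample points.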