diff --git a/README.md b/README.md index a1e9c9484a..81fdead098 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp - [Deep potential long-range](doc/model/dplr.md) - [Deep Potential - Range Correction (DPRc)](doc/model/dprc.md) - [Linear model](doc/model/linear.md) - - [Interpolation with a pairwise potential](doc/model/pairtab.md) + - [Interpolation or combination with a pairwise potential](doc/model/pairtab.md) - [Training](doc/train/index.md) - [Training a model](doc/train/training.md) - [Advanced options](doc/train/training-advanced.md) diff --git a/deepmd/model/model.py b/deepmd/model/model.py index dd439056b4..6117b4942d 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -97,6 +97,9 @@ def get_class_by_input(cls, input: dict): from deepmd.model.multi import ( MultiModel, ) + from deepmd.model.pairtab import ( + PairTabModel, + ) from deepmd.model.pairwise_dprc import ( PairwiseDPRc, ) @@ -112,6 +115,8 @@ def get_class_by_input(cls, input: dict): return FrozenModel elif model_type == "linear_ener": return LinearEnergyModel + elif model_type == "pairtab": + return PairTabModel else: raise ValueError(f"unknown model type: {model_type}") diff --git a/deepmd/model/pairtab.py b/deepmd/model/pairtab.py new file mode 100644 index 0000000000..38934818e6 --- /dev/null +++ b/deepmd/model/pairtab.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from enum import ( + Enum, +) +from typing import ( + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.env import ( + GLOBAL_TF_FLOAT_PRECISION, + MODEL_VERSION, + global_cvt_2_ener_float, + op_module, + tf, +) +from deepmd.fit.fitting import ( + Fitting, +) +from deepmd.loss.loss import ( + Loss, +) +from deepmd.model.model import ( + Model, +) +from deepmd.utils.pair_tab import ( + PairTab, +) + + +class PairTabModel(Model): + """Pairwise tabulation energy model. + + This model can be used to tabulate the pairwise energy between atoms for either + short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not + be used alone, but rather as one submodel of a linear (sum) model, such as + DP+D3. + + Do not put the model on the first model of a linear model, since the linear + model fetches the type map from the first model. + + At this moment, the model does not smooth the energy at the cutoff radius, so + one needs to make sure the energy has been smoothed to zero. + + Parameters + ---------- + tab_file : str + The path to the tabulation file. + rcut : float + The cutoff radius + sel : int or list[int] + The maxmum number of atoms in the cut-off radius + """ + + model_type = "ener" + + def __init__( + self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs + ): + super().__init__() + self.tab_file = tab_file + self.tab = PairTab(self.tab_file) + self.ntypes = self.tab.ntypes + self.rcut = rcut + if isinstance(sel, int): + self.sel = sel + elif isinstance(sel, list): + self.sel = sum(sel) + else: + raise TypeError("sel must be int or list[int]") + + def build( + self, + coord_: tf.Tensor, + atype_: tf.Tensor, + natoms: tf.Tensor, + box: tf.Tensor, + mesh: tf.Tensor, + input_dict: dict, + frz_model: Optional[str] = None, + ckpt_meta: Optional[str] = None, + suffix: str = "", + reuse: Optional[Union[bool, Enum]] = None, + ): + """Build the model. + + Parameters + ---------- + coord_ : tf.Tensor + The coordinates of atoms + atype_ : tf.Tensor + The atom types of atoms + natoms : tf.Tensor + The number of atoms + box : tf.Tensor + The box vectors + mesh : tf.Tensor + The mesh vectors + input_dict : dict + The input dict + frz_model : str, optional + The path to the frozen model + ckpt_meta : str, optional + The path prefix of the checkpoint and meta files + suffix : str, optional + The suffix of the scope + reuse : bool or tf.AUTO_REUSE, optional + Whether to reuse the variables + + Returns + ------- + dict + The output dict + """ + tab_info, tab_data = self.tab.get() + with tf.variable_scope("model_attr" + suffix, reuse=reuse): + self.tab_info = tf.get_variable( + "t_tab_info", + tab_info.shape, + dtype=tf.float64, + trainable=False, + initializer=tf.constant_initializer(tab_info, dtype=tf.float64), + ) + self.tab_data = tf.get_variable( + "t_tab_data", + tab_data.shape, + dtype=tf.float64, + trainable=False, + initializer=tf.constant_initializer(tab_data, dtype=tf.float64), + ) + t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string) + t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) + t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) + + with tf.variable_scope("fitting_attr" + suffix, reuse=reuse): + t_dfparam = tf.constant(0, name="dfparam", dtype=tf.int32) + t_daparam = tf.constant(0, name="daparam", dtype=tf.int32) + with tf.variable_scope("descrpt_attr" + suffix, reuse=reuse): + t_ntypes = tf.constant(self.ntypes, name="ntypes", dtype=tf.int32) + t_rcut = tf.constant( + self.rcut, name="rcut", dtype=GLOBAL_TF_FLOAT_PRECISION + ) + coord = tf.reshape(coord_, [-1, natoms[1] * 3]) + atype = tf.reshape(atype_, [-1, natoms[1]]) + box = tf.reshape(box, [-1, 9]) + # perhaps we need a OP that only outputs rij and nlist + ( + _, + _, + rij, + nlist, + _, + _, + ) = op_module.prod_env_mat_a_mix( + coord, + atype, + natoms, + box, + mesh, + np.zeros([self.ntypes, self.sel * 4]), + np.ones([self.ntypes, self.sel * 4]), + rcut_a=-1, + rcut_r=self.rcut, + rcut_r_smth=self.rcut, + sel_a=[self.sel], + sel_r=[0], + ) + scale = tf.ones([tf.shape(coord)[0], natoms[0]], dtype=tf.float64) + tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab( + self.tab_info, + self.tab_data, + atype, + rij, + nlist, + natoms, + scale, + sel_a=[self.sel], + sel_r=[0], + ) + energy_raw = tf.reshape( + tab_atom_ener, [-1, natoms[0]], name="o_atom_energy" + suffix + ) + energy = tf.reduce_sum( + global_cvt_2_ener_float(energy_raw), axis=1, name="o_energy" + suffix + ) + force = tf.reshape(tab_force, [-1, 3 * natoms[1]], name="o_force" + suffix) + virial = tf.reshape( + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1), + [-1, 9], + name="o_virial" + suffix, + ) + atom_virial = tf.reshape( + tab_atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix + ) + model_dict = {} + model_dict["energy"] = energy + model_dict["force"] = force + model_dict["virial"] = virial + model_dict["atom_ener"] = energy_raw + model_dict["atom_virial"] = atom_virial + model_dict["coord"] = coord + model_dict["atype"] = atype + + return model_dict + + def init_variables( + self, + graph: tf.Graph, + graph_def: tf.GraphDef, + model_type: str = "original_model", + suffix: str = "", + ) -> None: + """Init the embedding net variables with the given frozen model. + + Parameters + ---------- + graph : tf.Graph + The input frozen model graph + graph_def : tf.GraphDef + The input frozen model graph_def + model_type : str + the type of the model + suffix : str + suffix to name scope + """ + # skip. table can be initialized from the file + + def get_fitting(self) -> Union[Fitting, dict]: + """Get the fitting(s).""" + # nothing needs to do + return {} + + def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]: + """Get the loss function(s).""" + # nothing nees to do + return + + def get_rcut(self) -> float: + """Get cutoff radius of the model.""" + return self.rcut + + def get_ntypes(self) -> int: + """Get the number of types.""" + return self.ntypes + + def data_stat(self, data: dict): + """Data staticis.""" + # nothing needs to do + + def enable_compression(self, suffix: str = "") -> None: + """Enable compression. + + Parameters + ---------- + suffix : str + suffix to name scope + """ + # nothing needs to do + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: + """Update the selection and perform neighbor statistics. + + Notes + ----- + Do not modify the input data without copying it. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + """ + from deepmd.entrypoints.train import ( + update_one_sel, + ) + + local_jdata_cpy = local_jdata.copy() + return update_one_sel(global_jdata, local_jdata_cpy, True) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8d09d25577..2c1d235801 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -927,6 +927,7 @@ def model_args(exclude_hybrid=False): standard_model_args(), multi_model_args(), frozen_model_args(), + pairtab_model_args(), *hybrid_models, ], optional=True, @@ -1013,6 +1014,26 @@ def frozen_model_args() -> Argument: return ca +def pairtab_model_args() -> Argument: + doc_tab_file = "Path to the tabulation file." + doc_rcut = "The cut-off radius." + doc_sel = 'This parameter set the number of selected neighbors. Note that this parameter is a little different from that in other descriptors. Instead of separating each type of atoms, only the summation matters. And this number is highly related with the efficiency, thus one should not make it too large. Usually 200 or less is enough, far away from the GPU limitation 4096. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. Only the summation of `sel[i]` matters, and it is recommended to be less than 200.\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wraped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + ca = Argument( + "pairtab", + dict, + [ + Argument("tab_file", str, optional=False, doc=doc_tab_file), + Argument("rcut", float, optional=False, doc=doc_rcut), + Argument("sel", [int, List[int], str], optional=False, doc=doc_sel), + ], + doc="Pairwise tabulation energy model.", + ) + return ca + + def linear_ener_model_args() -> Argument: doc_weights = ( "If the type is list of float, a list of weights for each model. " diff --git a/doc/model/index.md b/doc/model/index.md index 6c128028a6..589b39b2b5 100644 --- a/doc/model/index.md +++ b/doc/model/index.md @@ -17,4 +17,4 @@ - [Deep potential long-range](dplr.md) - [Deep Potential - Range Correction (DPRc)](dprc.md) - [Linear model](linear.md) -- [Interpolation with a pairwise potential](pairtab.md) +- [Interpolation or combination with a pairwise potential](pairtab.md) diff --git a/doc/model/pairtab.md b/doc/model/pairtab.md index e3f0118f2c..115345796a 100644 --- a/doc/model/pairtab.md +++ b/doc/model/pairtab.md @@ -1,4 +1,4 @@ -# Interpolation with a pairwise potential +# Interpolation or combination with a pairwise potential ## Theory In applications like the radiation damage simulation, the interatomic distance may become too close, so that the DFT calculations fail. @@ -33,3 +33,56 @@ where the scale $\alpha_s$ is a tunable scale of the interatomic distance $r_{ij The pairwise potential $u^{\textrm{pair}}(r)$ is defined by a user-defined table that provides the value of $u^{\textrm{pair}}$ on an evenly discretized grid from 0 to the cutoff distance.[^1] [^1]: This section is built upon Jinzhe Zeng, Duo Zhang, Denghui Lu, Pinghui Mo, Zeyu Li, Yixiao Chen, Marián Rynik, Li'ang Huang, Ziyao Li, Shaochen Shi, Yingze Wang, Haotian Ye, Ping Tuo, Jiabin Yang, Ye Ding, Yifan Li, Davide Tisi, Qiyu Zeng, Han Bao, Yu Xia, Jiameng Huang, Koki Muraoka, Yibo Wang, Junhan Chang, Fengbo Yuan, Sigbjørn Løland Bore, Chun Cai, Yinnian Lin, Bo Wang, Jiayan Xu, Jia-Xin Zhu, Chenxing Luo, Yuzhi Zhang, Rhys E. A. Goodall, Wenshuo Liang, Anurag Kumar Singh, Sikai Yao, Jingchao Zhang, Renata Wentzcovitch, Jiequn Han, Jie Liu, Weile Jia, Darrin M. York, Weinan E, Roberto Car, Linfeng Zhang, Han Wang, [J. Chem. Phys. 159, 054801 (2023)](https://doi.org/10.1063/5.0155600) licensed under a [Creative Commons Attribution (CC BY) license](http://creativecommons.org/licenses/by/4.0/). + +DeePMD-kit also supports combination with a pairwise potential: + +```math + E_i = E_i^{\mathrm{DP}} + E_i^{\mathrm{pair}}, +``` + +## Table file + +The table file should be a text file that can be read by {py:meth}`numpy.loadtxt`. +The first column is the distance between two atoms, where upper range should be larger than the cutoff radius. +Other columns are two-body interaction energies for pairs of certain types, +in the order of Type_0-Type_0, Type_0-Type_1, ..., Type_0-Type_N, Type_1-Type_1, ..., Type_1-Type_N, ..., and Type_N-Type_N. + +The interaction should be smooth at the cut-off distance. + +## Interpolation with a short-range pairwise potential + +```json +"model": { + "use_srtab": "H2O_tab_potential.txt", + "smin_alpha": 0.1, + "sw_rmin": 0.8, + "sw_rmax": 1.0, + "_comment": "Below uses a normal DP model" +} +``` + +{ref}`sw_rmin ` and {ref}`sw_rmax ` must be smaller than the cutoff radius of the DP model. + +## Combination with a pairwise potential + +To combine with a pairwise potential, use the [linear model](./linear.md): + +```json +"model": { + "type": "linear_ener", + "weights": "sum", + "models": [ + { + "_comment": "Here uses a normal DP model" + }, + { + "type": "pairtab", + "tab_file": "dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + ] +} +``` + +The {ref}`rcut ` can be larger than that of the DP model. diff --git a/examples/water/d3/README.md b/examples/water/d3/README.md new file mode 100644 index 0000000000..bd75960010 --- /dev/null +++ b/examples/water/d3/README.md @@ -0,0 +1,11 @@ +# DPD3 + +`dftd3.txt` tabulates D3 dispersion for each pair of types (O-O, O-H, H-H). +It can be generated by [simple-dftd3](https://github.com/dftd3/simple-dftd3). + +## Note + +As an example, it cannot be used in production: + +- For small file sizes in the repository, the distance interval in the tabulation is only 0.1. +- The example training data does not contain dispersion interaction. diff --git a/examples/water/d3/dftd3.txt b/examples/water/d3/dftd3.txt new file mode 100644 index 0000000000..bbc9726134 --- /dev/null +++ b/examples/water/d3/dftd3.txt @@ -0,0 +1,100 @@ +1.000000000000000056e-01 -5.836993924755046366e-03 -3.207255698139210940e-03 -1.843064837882633228e-03 +2.000000000000000111e-01 -5.836993806911452108e-03 -3.207255613696154226e-03 -1.843064776130543892e-03 +3.000000000000000444e-01 -5.836992560106194113e-03 -3.207254720510349828e-03 -1.843064123123401392e-03 +4.000000000000000222e-01 -5.836986225627246658e-03 -3.207250184384043221e-03 -1.843060811677158526e-03 +5.000000000000000000e-01 -5.836964436915091821e-03 -3.207234589497737730e-03 -1.843052788205641135e-03 +5.999999999999999778e-01 -5.836905460107320170e-03 -3.207192410957825698e-03 -1.843338972660025360e-03 +7.000000000000000666e-01 -5.836769626930583300e-03 -3.207096085246822614e-03 -1.851839876215982238e-03 +8.000000000000000444e-01 -5.836491030513121618e-03 -3.206924889333430135e-03 -2.035200426069873857e-03 +9.000000000000000222e-01 -5.835967602710929840e-03 -3.206999537190755728e-03 -3.724418810291191088e-03 +1.000000000000000000e+00 -5.835053775792304297e-03 -3.210477055685919626e-03 -4.311009958284344433e-03 +1.100000000000000089e+00 -5.833591489567684953e-03 -3.237527828601436623e-03 -4.381510573223419171e-03 +1.200000000000000178e+00 -5.831652981781070173e-03 -3.454845258034439960e-03 -4.394419437232751843e-03 +1.300000000000000266e+00 -5.830520601296543433e-03 -4.478070067533340692e-03 -4.394683688871586433e-03 +1.400000000000000133e+00 -5.835353622834494637e-03 -5.097530655625692915e-03 -4.389691198859401421e-03 +1.500000000000000222e+00 -5.863290690264541874e-03 -5.215500241204417201e-03 -4.380686516072217034e-03 +1.600000000000000089e+00 -6.007605076700822840e-03 -5.234994618743306349e-03 -4.367337507268855175e-03 +1.700000000000000178e+00 -6.481613230242359684e-03 -5.228094160806716871e-03 -4.348706108547779198e-03 +1.800000000000000266e+00 -6.814114687600298335e-03 -5.208252365588400719e-03 -4.323505520547227775e-03 +1.900000000000000133e+00 -6.876286379079538276e-03 -5.177988357772074675e-03 -4.290186895355558444e-03 +2.000000000000000000e+00 -6.858440816799354217e-03 -5.136887568332395605e-03 -4.246989919717190920e-03 +2.100000000000000089e+00 -6.810730159155128395e-03 -5.083475665301987606e-03 -4.192000168715152505e-03 +2.200000000000000178e+00 -6.742330737387775344e-03 -5.015815334399144516e-03 -4.123231519970332187e-03 +2.300000000000000266e+00 -6.653841351238824232e-03 -4.931782661310191510e-03 -4.038743210125123918e-03 +2.400000000000000355e+00 -6.543651317938833402e-03 -4.829269294496830317e-03 -3.936795390727530070e-03 +2.500000000000000444e+00 -6.409559281498313811e-03 -4.706385522261587705e-03 -3.816040239463167755e-03 +2.600000000000000089e+00 -6.249406635892575460e-03 -4.561685215972477100e-03 -3.675736338668155346e-03 +2.700000000000000178e+00 -6.061478463281754457e-03 -4.394408172892586353e-03 -3.515962176363645990e-03 +2.800000000000000266e+00 -5.844844934626365965e-03 -4.204716954930251029e-03 -3.337792190764940319e-03 +2.900000000000000355e+00 -5.599669004675433479e-03 -3.993889719587391009e-03 -3.143390268473208755e-03 +3.000000000000000444e+00 -5.327453506642119106e-03 -3.764420755089863558e-03 -2.935977648106832729e-03 +3.100000000000000089e+00 -5.031178000843260223e-03 -3.519982860915751074e-03 -2.719650568099894056e-03 +3.200000000000000178e+00 -4.715273672783852794e-03 -3.265225882759082918e-03 -2.499057451653833965e-03 +3.300000000000000266e+00 -4.385404785641488362e-03 -3.005422601424333727e-03 -2.278985743812388717e-03 +3.400000000000000355e+00 -4.048065433713449700e-03 -2.746015696661484231e-03 -2.063937321866260270e-03 +3.500000000000000444e+00 -3.710048572169818114e-03 -2.492149763588673555e-03 -1.857774171128685628e-03 +3.600000000000000089e+00 -3.377881092113224713e-03 -2.248275746149775312e-03 -1.663491260531681313e-03 +3.700000000000000178e+00 -3.057327225182689644e-03 -2.017890114824574810e-03 -1.483133951195727196e-03 +3.800000000000000266e+00 -2.753038981057491941e-03 -1.803430168074075671e-03 -1.317840750738439540e-03 +3.900000000000000355e+00 -2.468388171389931940e-03 -1.606308000309067743e-03 -1.167971059502070875e-03 +4.000000000000000000e+00 -2.205469013267805957e-03 -1.427041871266797194e-03 -1.033273795673775699e-03 +4.099999999999999645e+00 -1.965228953751702902e-03 -1.265437879541002862e-03 -9.130610310879381641e-04 +4.200000000000000178e+00 -1.747673832278765806e-03 -1.120782158543769547e-03 -8.063636493380576522e-04 +4.299999999999999822e+00 -1.552098284175109895e-03 -9.920168984562682292e-04 -7.120580835032176920e-04 +4.399999999999999467e+00 -1.377305748647780163e-03 -8.778864597897169646e-04 -6.289618864203703032e-04 +4.500000000000000000e+00 -1.221797526507303194e-03 -7.770496638083513111e-04 -5.559009474092405914e-04 +4.599999999999999645e+00 -1.083922782809847944e-03 -6.881603844395511003e-04 -4.917533939693695443e-04 +4.700000000000000178e+00 -9.619897379282633162e-04 -6.099214740721333600e-04 -4.354756390957214944e-04 +4.799999999999999822e+00 -8.543428352989788704e-04 -5.411178648690499965e-04 -3.861155118068372257e-04 +4.900000000000000355e+00 -7.594124385866309881e-04 -4.806343247547230249e-04 -3.428165131289927659e-04 +5.000000000000000000e+00 -6.757436744162991990e-04 -4.274624687438948085e-04 -3.048162971647301774e-04 +5.099999999999999645e+00 -6.020102408497160842e-04 -3.807006248475114439e-04 -2.714416410742632600e-04 +5.200000000000000178e+00 -5.370178955485286568e-04 -3.395492294862310413e-04 -2.421014916366724180e-04 +5.299999999999999822e+00 -4.797012289428498875e-04 -3.033036596191310643e-04 -2.162791601488694472e-04 +5.400000000000000355e+00 -4.291163603974148220e-04 -2.713458112672340397e-04 -1.935243599692976007e-04 +5.500000000000000000e+00 -3.844314156775488251e-04 -2.431352896687036106e-04 -1.734455139070628909e-04 +5.599999999999999645e+00 -3.449160478270333653e-04 -2.182007570692257958e-04 -1.557025751017268144e-04 +5.700000000000000178e+00 -3.099308250478081581e-04 -1.961317615550248216e-04 -1.400004825033046053e-04 +5.799999999999999822e+00 -2.789169965744946232e-04 -1.765712194195623135e-04 -1.260832928664179986e-04 +5.900000000000000355e+00 -2.513869308376957498e-04 -1.592086242469989389e-04 -1.137289820430557137e-04 +6.000000000000000000e+00 -2.269153740910770769e-04 -1.437739928526920498e-04 -1.027448796326863424e-04 +6.099999999999999645e+00 -2.051315821421489645e-04 -1.300325201124615134e-04 -9.296368585686617101e-05 +6.200000000000000178e+00 -1.857123177371916057e-04 -1.177798933810950252e-04 -8.424001307773229020e-05 +6.299999999999999822e+00 -1.683756703844696025e-04 -1.068382068924710331e-04 -7.644739339328402133e-05 +6.400000000000000355e+00 -1.528756359693242027e-04 -9.705241326038571551e-05 -6.947569600606938272e-05 +6.500000000000000000e+00 -1.389973847900246836e-04 -8.828725024164000819e-05 -6.322890211399061677e-05 +6.599999999999999645e+00 -1.265531447216864910e-04 -8.042458445906290783e-05 -5.762318996708282935e-05 +6.700000000000000178e+00 -1.153786284350462083e-04 -7.336111861455703732e-05 -5.258528787829301387e-05 +6.799999999999999822e+00 -1.053299381724164837e-04 -6.700641408290249014e-05 -4.805105801105378322e-05 +6.900000000000000355e+00 -9.628088734156424651e-05 -6.128118618925484863e-05 -4.396427848897285724e-05 +7.000000000000000000e+00 -8.812068437769617318e-05 -5.611583465513190913e-05 -4.027559568117881135e-05 +7.099999999999999645e+00 -8.075193047879847589e-05 -5.144917649553730292e-05 -3.694162237238345173e-05 +7.200000000000000178e+00 -7.408888866698059216e-05 -4.722735299269381154e-05 -3.392416093003276399e-05 +7.299999999999999822e+00 -6.805598702152939358e-05 -4.340288624040664528e-05 -3.118953355495664799e-05 +7.400000000000000355e+00 -6.258652380321327402e-05 -3.993386415929820209e-05 -2.870800428147100793e-05 +7.500000000000000000e+00 -5.762154653038724025e-05 -3.678323585657680581e-05 -2.645327961777798337e-05 +7.599999999999999645e+00 -5.310888089285451013e-05 -3.391820178292012157e-05 -2.440207662848582849e-05 +7.700000000000000178e+00 -4.900228873380196631e-05 -3.130968536501008690e-05 -2.253374889733842244e-05 +7.799999999999999822e+00 -4.526073723647751752e-05 -2.893187470658392955e-05 -2.082996220611386151e-05 +7.900000000000000355e+00 -4.184776396661089387e-05 -2.676182459276817884e-05 -1.927441295794754013e-05 +8.000000000000000000e+00 -3.873092458939377268e-05 -2.477911043795883125e-05 -1.785258338919504348e-05 +8.099999999999999645e+00 -3.588131194417033489e-05 -2.296552701898519263e-05 -1.655152847892165657e-05 +8.199999999999999289e+00 -3.327313676038550535e-05 -2.130482586144845337e-05 -1.535969020138178571e-05 +8.300000000000000711e+00 -3.088336167038252842e-05 -1.978248602307972753e-05 -1.426673539356727330e-05 +8.400000000000000355e+00 -2.869138134992016182e-05 -1.838551376555334571e-05 -1.326341404347894562e-05 +8.500000000000000000e+00 -2.667874262351516647e-05 -1.710226724425689568e-05 -1.234143525923654349e-05 +8.599999999999999645e+00 -2.482889923305170253e-05 -1.592230289025118773e-05 -1.149335856644029766e-05 +8.699999999999999289e+00 -2.312699670543422217e-05 -1.483624062390156494e-05 -1.071249851148625846e-05 +8.800000000000000711e+00 -2.155968338640409072e-05 -1.383564543725925445e-05 -9.992840830439041262e-06 +8.900000000000000355e+00 -2.011494424844594071e-05 -1.291292322228744002e-05 -9.328968683880578804e-06 +9.000000000000000000e+00 -1.878195454422273937e-05 -1.206122901303710759e-05 -8.715997664073560640e-06 +9.099999999999999645e+00 -1.755095077450199208e-05 -1.127438605916122122e-05 -8.149518457037649769e-06 +9.199999999999999289e+00 -1.641311678073074286e-05 -1.054681436190059537e-05 -7.625546193173260002e-06 +9.300000000000000711e+00 -1.536048306550537418e-05 -9.873467487129955082e-06 -7.140475649634374041e-06 +9.400000000000000355e+00 -1.438583769617946587e-05 -9.249776627676899564e-06 -6.691041578929334717e-06 +9.500000000000000000e+00 -1.348264736372320039e-05 -8.671601022702781346e-06 -6.274283533910064576e-06 +9.599999999999999645e+00 -1.264498735578012246e-05 -8.135183958678718279e-06 -5.887514641681122086e-06 +9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06 +9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06 +9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06 +1.000000000000000000e+01 0.000000000000000e00e+00 0.000000000000000e00e+00 0.000000000000000e00e+00 diff --git a/examples/water/d3/input.json b/examples/water/d3/input.json new file mode 100644 index 0000000000..bbe7a2c8a9 --- /dev/null +++ b/examples/water/d3/input.json @@ -0,0 +1,95 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "linear_ener", + "weights": "sum", + "models": [ + { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + { + "type": "pairtab", + "tab_file": "dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + "_comment10": "that's all" +} diff --git a/source/tests/test_model_pairtab.py b/source/tests/test_model_pairtab.py new file mode 100644 index 0000000000..fd678894b5 --- /dev/null +++ b/source/tests/test_model_pairtab.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import scipy.spatial.distance +from common import ( + DataSystem, + gen_data, + j_loader, +) + +from deepmd.common import ( + j_must_have, +) +from deepmd.env import ( + tf, +) +from deepmd.model.model import ( + Model, +) + +GLOBAL_ENER_FLOAT_PRECISION = tf.float64 +GLOBAL_TF_FLOAT_PRECISION = tf.float64 +GLOBAL_NP_FLOAT_PRECISION = np.float64 + + +class TestModel(tf.test.TestCase): + def setUp(self): + gen_data() + + def test_model(self): + jfile = "water.json" + jdata = j_loader(jfile) + systems = j_must_have(jdata, "systems") + set_pfx = j_must_have(jdata, "set_prefix") + batch_size = 1 + test_size = 1 + + tab_filename = "test_pairtab_tab.txt" + jdata["model"] = { + "type": "pairtab", + "tab_file": tab_filename, + "rcut": 6, + "sel": [6], + } + rcut = j_must_have(jdata["model"], "rcut") + + def pair_pot(r: float): + # LJ, as exmaple + return 4 * (1 / r**12 - 1 / r**6) + + dx = 1e-4 + d = np.arange(dx, rcut + dx, dx) + tab = np.array( + [ + d, + pair_pot(d), + pair_pot(d), + pair_pot(d), + ] + ).T + np.savetxt(tab_filename, tab) + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) + + test_data = data.get_test() + numb_test = 1 + + model = Model( + **jdata["model"], + ) + + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") + t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") + t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") + t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") + t_atom_ener = tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" + ) + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") + t_type = tf.placeholder(tf.int32, [None], name="i_type") + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") + t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") + is_training = tf.placeholder(tf.bool) + t_fparam = None + + model_pred = model.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + t_fparam, + suffix="test_pairtab", + reuse=False, + ) + energy = model_pred["energy"] + force = model_pred["force"] + virial = model_pred["virial"] + + feed_dict_test = { + t_prop_c: test_data["prop_c"], + t_energy: test_data["energy"][:numb_test], + t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), + t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), + t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), + t_box: test_data["box"][:numb_test, :], + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), + t_natoms: test_data["natoms_vec"], + t_mesh: [], # nopbc + is_training: False, + } + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + [e, _, _] = sess.run([energy, force, virial], feed_dict=feed_dict_test) + + e = e.reshape([-1]) + + coord = test_data["coord"][0, :].reshape(-1, 3) + distance = scipy.spatial.distance.cdist(coord, coord).ravel() + refe = [np.sum(pair_pot(distance[np.nonzero(distance)])) / 2] + + refe = np.reshape(refe, [-1]) + + places = 10 + np.testing.assert_almost_equal(e, refe, places)