From f47dc669b466da153f4b615187ae330bea25844f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 1 Jan 2024 10:07:30 -0500 Subject: [PATCH] Add pairwise tabulation as an independent model Add pairwise tabulation as an independent model, which can be summed with DP (DP + PairTab) by the linear model, other than interpolation. PairTab can be used for any pairwise potentials, e.g., d3, LJ, ZBL, etc. Signed-off-by: Jinzhe Zeng --- deepmd/model/model.py | 5 + deepmd/model/pairtab.py | 274 +++++++++++++++++++++++++++++ deepmd/utils/argcheck.py | 21 +++ examples/water/d3/README.md | 11 ++ examples/water/d3/dftd3.txt | 100 +++++++++++ examples/water/d3/input.json | 95 ++++++++++ source/tests/test_model_pairtab.py | 131 ++++++++++++++ 7 files changed, 637 insertions(+) create mode 100644 deepmd/model/pairtab.py create mode 100644 examples/water/d3/README.md create mode 100644 examples/water/d3/dftd3.txt create mode 100644 examples/water/d3/input.json create mode 100644 source/tests/test_model_pairtab.py diff --git a/deepmd/model/model.py b/deepmd/model/model.py index dd439056b4..6117b4942d 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -97,6 +97,9 @@ def get_class_by_input(cls, input: dict): from deepmd.model.multi import ( MultiModel, ) + from deepmd.model.pairtab import ( + PairTabModel, + ) from deepmd.model.pairwise_dprc import ( PairwiseDPRc, ) @@ -112,6 +115,8 @@ def get_class_by_input(cls, input: dict): return FrozenModel elif model_type == "linear_ener": return LinearEnergyModel + elif model_type == "pairtab": + return PairTabModel else: raise ValueError(f"unknown model type: {model_type}") diff --git a/deepmd/model/pairtab.py b/deepmd/model/pairtab.py new file mode 100644 index 0000000000..dda605a7aa --- /dev/null +++ b/deepmd/model/pairtab.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from enum import ( + Enum, +) +from typing import ( + List, + Optional, + Union, +) + +import numpy as np + +from deepmd.env import ( + global_cvt_2_ener_float, + op_module, + tf, +) +from deepmd.fit.fitting import ( + Fitting, +) +from deepmd.loss.loss import ( + Loss, +) +from deepmd.model.model import ( + Model, +) +from deepmd.utils.pair_tab import ( + PairTab, +) + + +class PairTabModel(Model): + """Pairwise tabulation energy model. + + This model can be used to tabulate the pairwise energy between atoms for either + short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not + be used alone, but rather as one submodel of a linear (sum) model, such as + DP+D3. + + Do not put the model on the first model of a linear model, since the linear + model fetches the type map from the first model. + + At this moment, the model does not smooth the energy at the cutoff radius, so + one needs to make sure the energy has been smoothed to zero. + + Parameters + ---------- + tab_file : str + The path to the tabulation file. + rcut : float + The cutoff radius + sel : int or list[int] + The maxmum number of atoms in the cut-off radius + """ + + model_type = "ener" + + def __init__( + self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs + ): + super().__init__() + self.tab_file = tab_file + self.tab = PairTab(self.tab_file) + self.ntypes = self.tab.ntypes + self.rcut = rcut + if isinstance(sel, int): + self.sel = sel + elif isinstance(sel, list): + self.sel = sum(sel) + else: + raise TypeError("sel must be int or list[int]") + + def build( + self, + coord_: tf.Tensor, + atype_: tf.Tensor, + natoms: tf.Tensor, + box: tf.Tensor, + mesh: tf.Tensor, + input_dict: dict, + frz_model: Optional[str] = None, + ckpt_meta: Optional[str] = None, + suffix: str = "", + reuse: Optional[Union[bool, Enum]] = None, + ): + """Build the model. + + Parameters + ---------- + coord_ : tf.Tensor + The coordinates of atoms + atype_ : tf.Tensor + The atom types of atoms + natoms : tf.Tensor + The number of atoms + box : tf.Tensor + The box vectors + mesh : tf.Tensor + The mesh vectors + input_dict : dict + The input dict + frz_model : str, optional + The path to the frozen model + ckpt_meta : str, optional + The path prefix of the checkpoint and meta files + suffix : str, optional + The suffix of the scope + reuse : bool or tf.AUTO_REUSE, optional + Whether to reuse the variables + + Returns + ------- + dict + The output dict + """ + tab_info, tab_data = self.tab.get() + with tf.variable_scope("model_attr" + suffix, reuse=reuse): + self.tab_info = tf.get_variable( + "t_tab_info", + tab_info.shape, + dtype=tf.float64, + trainable=False, + initializer=tf.constant_initializer(tab_info, dtype=tf.float64), + ) + self.tab_data = tf.get_variable( + "t_tab_data", + tab_data.shape, + dtype=tf.float64, + trainable=False, + initializer=tf.constant_initializer(tab_data, dtype=tf.float64), + ) + coord = tf.reshape(coord_, [-1, natoms[1] * 3]) + atype = tf.reshape(atype_, [-1, natoms[1]]) + box = tf.reshape(box, [-1, 9]) + # perhaps we need a OP that only outputs rij and nlist + ( + _, + _, + rij, + nlist, + _, + _, + ) = op_module.prod_env_mat_a_mix( + coord, + atype, + natoms, + box, + mesh, + np.zeros([self.ntypes, self.sel * 4]), + np.ones([self.ntypes, self.sel * 4]), + rcut_a=-1, + rcut_r=self.rcut, + rcut_r_smth=self.rcut, + sel_a=[self.sel], + sel_r=[0], + ) + scale = tf.ones([tf.shape(coord)[0], natoms[0]], dtype=tf.float64) + tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab( + self.tab_info, + self.tab_data, + atype, + rij, + nlist, + natoms, + scale, + sel_a=[self.sel], + sel_r=[0], + ) + energy_raw = tf.reshape( + tab_atom_ener, [-1, natoms[0]], name="o_atom_energy" + suffix + ) + energy = tf.reduce_sum( + global_cvt_2_ener_float(energy_raw), axis=1, name="o_energy" + suffix + ) + force = tf.reshape(tab_force, [-1, 3 * natoms[1]], name="o_force" + suffix) + virial = tf.reshape( + tf.reduce_sum(tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1), + [-1, 9], + name="o_virial" + suffix, + ) + atom_virial = tf.reshape( + tab_atom_virial, [-1, 9 * natoms[1]], name="o_atom_virial" + suffix + ) + model_dict = {} + model_dict["energy"] = energy + model_dict["force"] = force + model_dict["virial"] = virial + model_dict["atom_ener"] = energy_raw + model_dict["atom_virial"] = atom_virial + model_dict["coord"] = coord + model_dict["atype"] = atype + + return model_dict + + def init_variables( + self, + graph: tf.Graph, + graph_def: tf.GraphDef, + model_type: str = "original_model", + suffix: str = "", + ) -> None: + """Init the embedding net variables with the given frozen model. + + Parameters + ---------- + graph : tf.Graph + The input frozen model graph + graph_def : tf.GraphDef + The input frozen model graph_def + model_type : str + the type of the model + suffix : str + suffix to name scope + """ + # skip. table can be initialized from the file + + def get_fitting(self) -> Union[Fitting, dict]: + """Get the fitting(s).""" + # nothing needs to do + return {} + + def get_loss(self, loss: dict, lr) -> Optional[Union[Loss, dict]]: + """Get the loss function(s).""" + # nothing nees to do + return + + def get_rcut(self) -> float: + """Get cutoff radius of the model.""" + return self.rcut + + def get_ntypes(self) -> int: + """Get the number of types.""" + return self.ntypes + + def data_stat(self, data: dict): + """Data staticis.""" + # nothing needs to do + + def enable_compression(self, suffix: str = "") -> None: + """Enable compression. + + Parameters + ---------- + suffix : str + suffix to name scope + """ + # nothing needs to do + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: + """Update the selection and perform neighbor statistics. + + Notes + ----- + Do not modify the input data without copying it. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + """ + from deepmd.entrypoints.train import ( + update_one_sel, + ) + + local_jdata_cpy = local_jdata.copy() + return update_one_sel(global_jdata, local_jdata_cpy, True) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8d09d25577..c8945cee74 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -927,6 +927,7 @@ def model_args(exclude_hybrid=False): standard_model_args(), multi_model_args(), frozen_model_args(), + pairtab_model_args(), *hybrid_models, ], optional=True, @@ -1013,6 +1014,26 @@ def frozen_model_args() -> Argument: return ca +def pairtab_model_args() -> Argument: + doc_tab_file = "Path to the tabulation file." + doc_rcut = "The cutoff radius." + doc_sel = 'This parameter set the number of selected neighbors. Note that this parameter is a little different from that in other descriptors. Instead of separating each type of atoms, only the summation matters. And this number is highly related with the efficiency, thus one should not make it too large. Usually 200 or less is enough, far away from the GPU limitation 4096. It can be:\n\n\ + - `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\ + - `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. Only the summation of `sel[i]` matters, and it is recommended to be less than 200.\ + - `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wraped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + ca = Argument( + "pairtab", + dict, + [ + Argument("tab_file", str, optional=False, doc=doc_tab_file), + Argument("rcut", float, optional=False, doc="The cutoff radius."), + Argument("sel", [int, List[int], str], optional=False, doc=doc_sel), + ], + doc="Pairwise tabulation energy model.", + ) + return ca + + def linear_ener_model_args() -> Argument: doc_weights = ( "If the type is list of float, a list of weights for each model. " diff --git a/examples/water/d3/README.md b/examples/water/d3/README.md new file mode 100644 index 0000000000..bd75960010 --- /dev/null +++ b/examples/water/d3/README.md @@ -0,0 +1,11 @@ +# DPD3 + +`dftd3.txt` tabulates D3 dispersion for each pair of types (O-O, O-H, H-H). +It can be generated by [simple-dftd3](https://github.com/dftd3/simple-dftd3). + +## Note + +As an example, it cannot be used in production: + +- For small file sizes in the repository, the distance interval in the tabulation is only 0.1. +- The example training data does not contain dispersion interaction. diff --git a/examples/water/d3/dftd3.txt b/examples/water/d3/dftd3.txt new file mode 100644 index 0000000000..bbc9726134 --- /dev/null +++ b/examples/water/d3/dftd3.txt @@ -0,0 +1,100 @@ +1.000000000000000056e-01 -5.836993924755046366e-03 -3.207255698139210940e-03 -1.843064837882633228e-03 +2.000000000000000111e-01 -5.836993806911452108e-03 -3.207255613696154226e-03 -1.843064776130543892e-03 +3.000000000000000444e-01 -5.836992560106194113e-03 -3.207254720510349828e-03 -1.843064123123401392e-03 +4.000000000000000222e-01 -5.836986225627246658e-03 -3.207250184384043221e-03 -1.843060811677158526e-03 +5.000000000000000000e-01 -5.836964436915091821e-03 -3.207234589497737730e-03 -1.843052788205641135e-03 +5.999999999999999778e-01 -5.836905460107320170e-03 -3.207192410957825698e-03 -1.843338972660025360e-03 +7.000000000000000666e-01 -5.836769626930583300e-03 -3.207096085246822614e-03 -1.851839876215982238e-03 +8.000000000000000444e-01 -5.836491030513121618e-03 -3.206924889333430135e-03 -2.035200426069873857e-03 +9.000000000000000222e-01 -5.835967602710929840e-03 -3.206999537190755728e-03 -3.724418810291191088e-03 +1.000000000000000000e+00 -5.835053775792304297e-03 -3.210477055685919626e-03 -4.311009958284344433e-03 +1.100000000000000089e+00 -5.833591489567684953e-03 -3.237527828601436623e-03 -4.381510573223419171e-03 +1.200000000000000178e+00 -5.831652981781070173e-03 -3.454845258034439960e-03 -4.394419437232751843e-03 +1.300000000000000266e+00 -5.830520601296543433e-03 -4.478070067533340692e-03 -4.394683688871586433e-03 +1.400000000000000133e+00 -5.835353622834494637e-03 -5.097530655625692915e-03 -4.389691198859401421e-03 +1.500000000000000222e+00 -5.863290690264541874e-03 -5.215500241204417201e-03 -4.380686516072217034e-03 +1.600000000000000089e+00 -6.007605076700822840e-03 -5.234994618743306349e-03 -4.367337507268855175e-03 +1.700000000000000178e+00 -6.481613230242359684e-03 -5.228094160806716871e-03 -4.348706108547779198e-03 +1.800000000000000266e+00 -6.814114687600298335e-03 -5.208252365588400719e-03 -4.323505520547227775e-03 +1.900000000000000133e+00 -6.876286379079538276e-03 -5.177988357772074675e-03 -4.290186895355558444e-03 +2.000000000000000000e+00 -6.858440816799354217e-03 -5.136887568332395605e-03 -4.246989919717190920e-03 +2.100000000000000089e+00 -6.810730159155128395e-03 -5.083475665301987606e-03 -4.192000168715152505e-03 +2.200000000000000178e+00 -6.742330737387775344e-03 -5.015815334399144516e-03 -4.123231519970332187e-03 +2.300000000000000266e+00 -6.653841351238824232e-03 -4.931782661310191510e-03 -4.038743210125123918e-03 +2.400000000000000355e+00 -6.543651317938833402e-03 -4.829269294496830317e-03 -3.936795390727530070e-03 +2.500000000000000444e+00 -6.409559281498313811e-03 -4.706385522261587705e-03 -3.816040239463167755e-03 +2.600000000000000089e+00 -6.249406635892575460e-03 -4.561685215972477100e-03 -3.675736338668155346e-03 +2.700000000000000178e+00 -6.061478463281754457e-03 -4.394408172892586353e-03 -3.515962176363645990e-03 +2.800000000000000266e+00 -5.844844934626365965e-03 -4.204716954930251029e-03 -3.337792190764940319e-03 +2.900000000000000355e+00 -5.599669004675433479e-03 -3.993889719587391009e-03 -3.143390268473208755e-03 +3.000000000000000444e+00 -5.327453506642119106e-03 -3.764420755089863558e-03 -2.935977648106832729e-03 +3.100000000000000089e+00 -5.031178000843260223e-03 -3.519982860915751074e-03 -2.719650568099894056e-03 +3.200000000000000178e+00 -4.715273672783852794e-03 -3.265225882759082918e-03 -2.499057451653833965e-03 +3.300000000000000266e+00 -4.385404785641488362e-03 -3.005422601424333727e-03 -2.278985743812388717e-03 +3.400000000000000355e+00 -4.048065433713449700e-03 -2.746015696661484231e-03 -2.063937321866260270e-03 +3.500000000000000444e+00 -3.710048572169818114e-03 -2.492149763588673555e-03 -1.857774171128685628e-03 +3.600000000000000089e+00 -3.377881092113224713e-03 -2.248275746149775312e-03 -1.663491260531681313e-03 +3.700000000000000178e+00 -3.057327225182689644e-03 -2.017890114824574810e-03 -1.483133951195727196e-03 +3.800000000000000266e+00 -2.753038981057491941e-03 -1.803430168074075671e-03 -1.317840750738439540e-03 +3.900000000000000355e+00 -2.468388171389931940e-03 -1.606308000309067743e-03 -1.167971059502070875e-03 +4.000000000000000000e+00 -2.205469013267805957e-03 -1.427041871266797194e-03 -1.033273795673775699e-03 +4.099999999999999645e+00 -1.965228953751702902e-03 -1.265437879541002862e-03 -9.130610310879381641e-04 +4.200000000000000178e+00 -1.747673832278765806e-03 -1.120782158543769547e-03 -8.063636493380576522e-04 +4.299999999999999822e+00 -1.552098284175109895e-03 -9.920168984562682292e-04 -7.120580835032176920e-04 +4.399999999999999467e+00 -1.377305748647780163e-03 -8.778864597897169646e-04 -6.289618864203703032e-04 +4.500000000000000000e+00 -1.221797526507303194e-03 -7.770496638083513111e-04 -5.559009474092405914e-04 +4.599999999999999645e+00 -1.083922782809847944e-03 -6.881603844395511003e-04 -4.917533939693695443e-04 +4.700000000000000178e+00 -9.619897379282633162e-04 -6.099214740721333600e-04 -4.354756390957214944e-04 +4.799999999999999822e+00 -8.543428352989788704e-04 -5.411178648690499965e-04 -3.861155118068372257e-04 +4.900000000000000355e+00 -7.594124385866309881e-04 -4.806343247547230249e-04 -3.428165131289927659e-04 +5.000000000000000000e+00 -6.757436744162991990e-04 -4.274624687438948085e-04 -3.048162971647301774e-04 +5.099999999999999645e+00 -6.020102408497160842e-04 -3.807006248475114439e-04 -2.714416410742632600e-04 +5.200000000000000178e+00 -5.370178955485286568e-04 -3.395492294862310413e-04 -2.421014916366724180e-04 +5.299999999999999822e+00 -4.797012289428498875e-04 -3.033036596191310643e-04 -2.162791601488694472e-04 +5.400000000000000355e+00 -4.291163603974148220e-04 -2.713458112672340397e-04 -1.935243599692976007e-04 +5.500000000000000000e+00 -3.844314156775488251e-04 -2.431352896687036106e-04 -1.734455139070628909e-04 +5.599999999999999645e+00 -3.449160478270333653e-04 -2.182007570692257958e-04 -1.557025751017268144e-04 +5.700000000000000178e+00 -3.099308250478081581e-04 -1.961317615550248216e-04 -1.400004825033046053e-04 +5.799999999999999822e+00 -2.789169965744946232e-04 -1.765712194195623135e-04 -1.260832928664179986e-04 +5.900000000000000355e+00 -2.513869308376957498e-04 -1.592086242469989389e-04 -1.137289820430557137e-04 +6.000000000000000000e+00 -2.269153740910770769e-04 -1.437739928526920498e-04 -1.027448796326863424e-04 +6.099999999999999645e+00 -2.051315821421489645e-04 -1.300325201124615134e-04 -9.296368585686617101e-05 +6.200000000000000178e+00 -1.857123177371916057e-04 -1.177798933810950252e-04 -8.424001307773229020e-05 +6.299999999999999822e+00 -1.683756703844696025e-04 -1.068382068924710331e-04 -7.644739339328402133e-05 +6.400000000000000355e+00 -1.528756359693242027e-04 -9.705241326038571551e-05 -6.947569600606938272e-05 +6.500000000000000000e+00 -1.389973847900246836e-04 -8.828725024164000819e-05 -6.322890211399061677e-05 +6.599999999999999645e+00 -1.265531447216864910e-04 -8.042458445906290783e-05 -5.762318996708282935e-05 +6.700000000000000178e+00 -1.153786284350462083e-04 -7.336111861455703732e-05 -5.258528787829301387e-05 +6.799999999999999822e+00 -1.053299381724164837e-04 -6.700641408290249014e-05 -4.805105801105378322e-05 +6.900000000000000355e+00 -9.628088734156424651e-05 -6.128118618925484863e-05 -4.396427848897285724e-05 +7.000000000000000000e+00 -8.812068437769617318e-05 -5.611583465513190913e-05 -4.027559568117881135e-05 +7.099999999999999645e+00 -8.075193047879847589e-05 -5.144917649553730292e-05 -3.694162237238345173e-05 +7.200000000000000178e+00 -7.408888866698059216e-05 -4.722735299269381154e-05 -3.392416093003276399e-05 +7.299999999999999822e+00 -6.805598702152939358e-05 -4.340288624040664528e-05 -3.118953355495664799e-05 +7.400000000000000355e+00 -6.258652380321327402e-05 -3.993386415929820209e-05 -2.870800428147100793e-05 +7.500000000000000000e+00 -5.762154653038724025e-05 -3.678323585657680581e-05 -2.645327961777798337e-05 +7.599999999999999645e+00 -5.310888089285451013e-05 -3.391820178292012157e-05 -2.440207662848582849e-05 +7.700000000000000178e+00 -4.900228873380196631e-05 -3.130968536501008690e-05 -2.253374889733842244e-05 +7.799999999999999822e+00 -4.526073723647751752e-05 -2.893187470658392955e-05 -2.082996220611386151e-05 +7.900000000000000355e+00 -4.184776396661089387e-05 -2.676182459276817884e-05 -1.927441295794754013e-05 +8.000000000000000000e+00 -3.873092458939377268e-05 -2.477911043795883125e-05 -1.785258338919504348e-05 +8.099999999999999645e+00 -3.588131194417033489e-05 -2.296552701898519263e-05 -1.655152847892165657e-05 +8.199999999999999289e+00 -3.327313676038550535e-05 -2.130482586144845337e-05 -1.535969020138178571e-05 +8.300000000000000711e+00 -3.088336167038252842e-05 -1.978248602307972753e-05 -1.426673539356727330e-05 +8.400000000000000355e+00 -2.869138134992016182e-05 -1.838551376555334571e-05 -1.326341404347894562e-05 +8.500000000000000000e+00 -2.667874262351516647e-05 -1.710226724425689568e-05 -1.234143525923654349e-05 +8.599999999999999645e+00 -2.482889923305170253e-05 -1.592230289025118773e-05 -1.149335856644029766e-05 +8.699999999999999289e+00 -2.312699670543422217e-05 -1.483624062390156494e-05 -1.071249851148625846e-05 +8.800000000000000711e+00 -2.155968338640409072e-05 -1.383564543725925445e-05 -9.992840830439041262e-06 +8.900000000000000355e+00 -2.011494424844594071e-05 -1.291292322228744002e-05 -9.328968683880578804e-06 +9.000000000000000000e+00 -1.878195454422273937e-05 -1.206122901303710759e-05 -8.715997664073560640e-06 +9.099999999999999645e+00 -1.755095077450199208e-05 -1.127438605916122122e-05 -8.149518457037649769e-06 +9.199999999999999289e+00 -1.641311678073074286e-05 -1.054681436190059537e-05 -7.625546193173260002e-06 +9.300000000000000711e+00 -1.536048306550537418e-05 -9.873467487129955082e-06 -7.140475649634374041e-06 +9.400000000000000355e+00 -1.438583769617946587e-05 -9.249776627676899564e-06 -6.691041578929334717e-06 +9.500000000000000000e+00 -1.348264736372320039e-05 -8.671601022702781346e-06 -6.274283533910064576e-06 +9.599999999999999645e+00 -1.264498735578012246e-05 -8.135183958678718279e-06 -5.887514641681122086e-06 +9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06 +9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06 +9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06 +1.000000000000000000e+01 0.000000000000000e00e+00 0.000000000000000e00e+00 0.000000000000000e00e+00 diff --git a/examples/water/d3/input.json b/examples/water/d3/input.json new file mode 100644 index 0000000000..bbe7a2c8a9 --- /dev/null +++ b/examples/water/d3/input.json @@ -0,0 +1,95 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "linear_ener", + "weights": "sum", + "models": [ + { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "se_e2_a", + "sel": [ + 46, + 92 + ], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 25, + 50, + 100 + ], + "resnet_dt": false, + "axis_neuron": 16, + "precision": "float64", + "seed": 1, + "_comment2": " that's all" + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float64", + "seed": 1, + "_comment3": " that's all" + }, + "_comment4": " that's all" + }, + { + "type": "pairtab", + "tab_file": "dftd3.txt", + "rcut": 10.0, + "sel": 534 + } + ] + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + "_comment10": "that's all" +} diff --git a/source/tests/test_model_pairtab.py b/source/tests/test_model_pairtab.py new file mode 100644 index 0000000000..076005ba6c --- /dev/null +++ b/source/tests/test_model_pairtab.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import scipy.spatial.distance +from common import ( + DataSystem, + gen_data, + j_loader, +) + +from deepmd.common import ( + j_must_have, +) +from deepmd.env import ( + tf, +) +from deepmd.model.model import ( + Model, +) + +GLOBAL_ENER_FLOAT_PRECISION = tf.float64 +GLOBAL_TF_FLOAT_PRECISION = tf.float64 +GLOBAL_NP_FLOAT_PRECISION = np.float64 + + +class TestModel(tf.test.TestCase): + def setUp(self): + gen_data() + + def test_model(self): + jfile = "water.json" + jdata = j_loader(jfile) + systems = j_must_have(jdata, "systems") + set_pfx = j_must_have(jdata, "set_prefix") + batch_size = j_must_have(jdata, "batch_size") + test_size = j_must_have(jdata, "numb_test") + batch_size = 1 + test_size = 1 + + tab_filename = "test_pairtab_tab.txt" + jdata["model"] = { + "type": "pairtab", + "tab_file": tab_filename, + "rcut": 6, + "sel": [6], + } + rcut = j_must_have(jdata["model"], "rcut") + + def pair_pot(r: float): + # LJ, as exmaple + return 4 * (1 / r**12 - 1 / r**6) + + dx = 1e-4 + d = np.arange(dx, rcut + dx, dx) + tab = np.array( + [ + d, + pair_pot(d), + pair_pot(d), + pair_pot(d), + ] + ).T + np.savetxt(tab_filename, tab) + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) + + test_data = data.get_test() + numb_test = 1 + + model = Model( + **jdata["model"], + ) + + t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") + t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") + t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") + t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") + t_atom_ener = tf.placeholder( + GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" + ) + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") + t_type = tf.placeholder(tf.int32, [None], name="i_type") + t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") + t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") + is_training = tf.placeholder(tf.bool) + t_fparam = None + + model_pred = model.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + t_fparam, + suffix="test_pairtab", + reuse=False, + ) + energy = model_pred["energy"] + force = model_pred["force"] + virial = model_pred["virial"] + + feed_dict_test = { + t_prop_c: test_data["prop_c"], + t_energy: test_data["energy"][:numb_test], + t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), + t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), + t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), + t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), + t_box: test_data["box"][:numb_test, :], + t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), + t_natoms: test_data["natoms_vec"], + t_mesh: [], # nopbc + is_training: False, + } + + with self.cached_session() as sess: + sess.run(tf.global_variables_initializer()) + [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) + + e = e.reshape([-1]) + f = f.reshape([-1]) + v = v.reshape([-1]) + + coord = test_data["coord"][0, :].reshape(-1, 3) + distance = scipy.spatial.distance.cdist(coord, coord).ravel() + refe = [np.sum(pair_pot(distance[np.nonzero(distance)])) / 2] + + refe = np.reshape(refe, [-1]) + + places = 10 + np.testing.assert_almost_equal(e, refe, places)