diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index b13bfc17ba..bdff512311 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -80,6 +80,10 @@ def mixed_types(self) -> bool: """ return self.descriptor.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return self.descriptor.has_message_passing() + def forward_atomic( self, extended_coord: np.ndarray, diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 7dff9078c5..07cb6b560e 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -92,6 +92,10 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return any(model.has_message_passing() for model in self.models) + def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index 936c2b0943..2b47cd81e6 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -115,6 +115,10 @@ def mixed_types(self) -> bool: """ pass + @abstractmethod + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + @abstractmethod def fwd( self, diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index d3d179e6e2..4d9097a0e9 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -131,6 +131,10 @@ def mixed_types(self) -> bool: # to match DPA1 and DPA2. return True + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return False + def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index 444df1abf8..efd804496a 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -125,3 +125,7 @@ def call( ): """Calculate DescriptorBlock.""" pass + + @abstractmethod + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 8abc8c2777..d30dad5c10 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -349,6 +349,10 @@ def mixed_types(self) -> bool: """ return self.se_atten.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.se_atten.has_message_passing() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.se_atten.get_env_protection() @@ -885,6 +889,10 @@ def call( sw, ) + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False + class NeighborGatedAttention(NativeOP): def __init__( diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 96870c9bd7..f3e88ddacc 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -524,6 +524,12 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any( + [self.repinit.has_message_passing(), self.repformers.has_message_passing()] + ) + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index d359bf6fb7..6912590317 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -138,6 +138,10 @@ def mixed_types(self): """ return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) + def get_env_protection(self) -> float: """Returns the protection of building environment matrix. All descriptors should be the same.""" all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list] diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index a4fc8bddf9..328352c7d8 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -95,6 +95,10 @@ def mixed_types(self) -> bool: """ pass + @abstractmethod + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + @abstractmethod def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index c9ac48efec..3d2917cdf6 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -387,6 +387,10 @@ def call( rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True + # translated by GPT and modified def get_residual( diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 7a8899b4e5..b4c6fe564c 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -247,6 +247,10 @@ def mixed_types(self): """ return False + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index dbb6e104fb..e0a5621e41 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -205,6 +205,10 @@ def mixed_types(self): """ return False + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index cdbeb701ce..90e76115b4 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -185,6 +185,10 @@ def mixed_types(self): """ return False + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 7993f10abd..f8579de9a4 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -471,6 +471,10 @@ def mixed_types(self) -> bool: """ return self.atomic_model.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the model has message passing.""" + return self.atomic_model.has_message_passing() + def atomic_output_def(self) -> FittingOutputDef: """Get the output def of the atomic model.""" return self.atomic_model.atomic_output_def() diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index ab9992b6b5..8e37dbf09b 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -303,10 +303,7 @@ def train(FLAGS): def freeze(FLAGS): model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model) - if '"type": "dpa2"' in model.get_model_def_script(): - extra_files = {"type": "dpa2"} - else: - extra_files = {"type": "else"} + extra_files = {} torch.jit.save( model, FLAGS.output, diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 3d9a57bf70..90254e8c11 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -95,6 +95,10 @@ def mixed_types(self) -> bool: """ return self.descriptor.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return self.descriptor.has_message_passing() + def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 8b77f0c7c5..db8280cd02 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -104,6 +104,10 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return any(model.has_message_passing() for model in self.models) + def get_out_bias(self) -> torch.Tensor: return self.out_bias diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 3a0700be4f..ff1a83da6a 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -160,6 +160,10 @@ def mixed_types(self) -> bool: # to match DPA1 and DPA2. return True + def has_message_passing(self) -> bool: + """Returns whether the atomic model has message passing.""" + return False + def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 5e0cdac72b..28656d716c 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -169,6 +169,10 @@ def forward( """Calculate DescriptorBlock.""" pass + @abstractmethod + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + def make_default_type_embedding( ntypes, diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 92f5cf2e15..8f19aad961 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -342,6 +342,10 @@ def mixed_types(self) -> bool: """ return self.se_atten.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.se_atten.has_message_passing() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.se_atten.get_env_protection() diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index b33a528721..322c34734a 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -294,6 +294,12 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any( + [self.repinit.has_message_passing(), self.repformers.has_message_passing()] + ) + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" # the env_protection of repinit is the same as that of the repformer diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index e202005f4c..3733cec8e7 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -143,6 +143,10 @@ def mixed_types(self): """ return any(descrpt.mixed_types() for descrpt in self.descrpt_list) + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) + def get_env_protection(self) -> float: """Returns the protection of building environment matrix. All descriptors should be the same.""" all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list] diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index a66693653e..6d6bb9a54b 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -548,3 +548,7 @@ def get_stats(self) -> Dict[str, StatItem]: "The statistics of the descriptor has not been computed." ) return self.stats + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return True diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 350fceae2d..01a6d1ab38 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -147,6 +147,10 @@ def mixed_types(self): """ return self.sea.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.sea.has_message_passing() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.sea.get_env_protection() @@ -674,3 +678,7 @@ def forward( None, sw, ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index a59eaca409..a38051ce68 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -581,6 +581,10 @@ def forward( sw, ) + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False + class NeighborGatedAttention(nn.Module): def __init__( diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index b0a739f5e6..21fecd4857 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -170,6 +170,10 @@ def mixed_types(self) -> bool: """ return False + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 2c8f52709f..3b67e1657f 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -174,6 +174,10 @@ def mixed_types(self): """ return self.seat.mixed_types() + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.seat.has_message_passing() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.seat.get_env_protection() @@ -687,3 +691,7 @@ def forward( None, sw, ) + + def has_message_passing(self) -> bool: + """Returns whether the descriptor block has message passing.""" + return False diff --git a/deepmd/pt/model/model/frozen.py b/deepmd/pt/model/model/frozen.py index 148ffaa703..79bc450333 100644 --- a/deepmd/pt/model/model/frozen.py +++ b/deepmd/pt/model/model/frozen.py @@ -106,6 +106,11 @@ def mixed_types(self) -> bool: """ return self.model.mixed_types() + @torch.jit.export + def has_message_passing(self) -> bool: + """Returns whether the descriptor has message passing.""" + return self.model.has_message_passing() + @torch.jit.export def forward( self, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 989789c201..31e26dc718 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -531,6 +531,11 @@ def mixed_types(self) -> bool: """ return self.atomic_model.mixed_types() + @torch.jit.export + def has_message_passing(self) -> bool: + """Returns whether the model has message passing.""" + return self.atomic_model.has_message_passing() + def forward( self, coord, diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 232c9efd31..ea41bb32f6 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -59,13 +59,7 @@ void DeepPotPT::init(const std::string& model, } std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); - // TODO: This should be fixed after implement api to decide whether need to - // message passing and rename this metadata - if (metadata["type"] == "dpa2") { - do_message_passing = 1; - } else { - do_message_passing = 0; - } + do_message_passing = module.run_method("has_message_passing").toBool(); torch::jit::FusionStrategy strategy; strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}}; torch::jit::setFusionStrategy(strategy); diff --git a/source/api_cc/tests/test_deeppot_dpa_pt.cc b/source/api_cc/tests/test_deeppot_dpa_pt.cc index 416802cd20..94cadeaf19 100644 --- a/source/api_cc/tests/test_deeppot_dpa_pt.cc +++ b/source/api_cc/tests/test_deeppot_dpa_pt.cc @@ -25,30 +25,51 @@ class TestInferDeepPotDpaPt : public ::testing::Test { 3.51, 2.51, 2.60, 4.27, 3.22, 1.56}; std::vector atype = {0, 1, 1, 0, 1, 1}; std::vector box = {13., 0., 0., 0., 13., 0., 0., 0., 13.}; - std::vector expected_e = {-93.295296030283, -186.548183879333, - -186.988827037855, -93.295307298571, - -186.799369383945, -186.507754447584}; + // Generated by the following Python code: + // import numpy as np + // from deepmd.infer import DeepPot + // coord = np.array([ + // 12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + // 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + // 3.51, 2.51, 2.60, 4.27, 3.22, 1.56 + // ]).reshape(1, -1) + // atype = np.array([0, 1, 1, 0, 1, 1]) + // box = np.array([13., 0., 0., 0., 13., 0., 0., 0., 13.]).reshape(1, -1) + // dp = DeepPot("deeppot_dpa.pth") + // e, v, f, ae, av = dp.eval(coord, box, atype, atomic=True) + // np.set_printoptions(precision=16) + // print(f"{e.ravel()=} {v.ravel()=} {f.ravel()=} {ae.ravel()=} + // {av.ravel()=}") + + std::vector expected_e = { + -94.37720733019096, -187.43155959873033, -187.37830241580824, + -94.34880710985752, -187.38869830422271, -187.33919952642458}; std::vector expected_f = { - 4.964133039248, -0.542378158452, -0.381267990914, -0.563388054735, - 0.340320322541, 0.473406268590, 0.159774831398, 0.684651816874, - -0.377008867620, -4.718603033927, -0.012604322920, -0.425121993870, - -0.500302936762, -0.637586419292, 0.930351899011, 0.658386154778, - 0.167596761250, -0.220359315197}; + 5.402355596838843, -1.263284191331685, -0.697693239979719, + -1.025144852453706, 0.6554396369933394, 0.8817286288078215, + 0.4364579972147229, 1.2150079148857598, -0.6778076371985796, + -6.939243547937094, 0.1571084862688049, -0.9017435514431825, + 0.3597967524845581, -1.328808718007412, 2.0974306454214653, + 1.7657780538526762, 0.5645368711911929, -0.7019148456078053}; std::vector expected_v = { - -5.055176133632, -0.743392222876, 0.330846378467, -0.031111229868, - 0.018004461517, 0.170047655301, -0.063087726831, -0.004361215202, - -0.042920299661, 3.624188578021, -0.252818122305, -0.026516806138, - -0.014510755893, 0.103726553937, 0.181001311123, -0.508673535094, - 0.142101134395, 0.135339636607, -0.460067993361, 0.120541583338, - -0.206396390140, -0.630991740522, 0.397670086144, -0.427022150075, - 0.656463775044, -0.209989614377, 0.288974239790, -7.603428707029, - -0.912313971544, 0.882084544041, -0.807760666057, -0.070519570327, - 0.022164414763, 0.569448616709, 0.028522950109, 0.051641619288, - -1.452133900157, 0.037653156584, -0.144421326931, -0.308825789350, - 0.302020522568, -0.446073217801, 0.313539058423, -0.461052923736, - 0.678235442273, 1.429780276456, 0.080472825760, -0.103424652500, - 0.123343430648, 0.011879908277, -0.018897229721, -0.235518441452, - -0.013999547600, 0.027007016662}; + 9.5175137906314511e-01, -2.0801835688892991e+00, 4.6860789988973117e-01, + -6.0178723966859824e+00, 1.2556002911926123e-01, 4.7887097832213565e-02, + 5.6216590124464116e-01, 1.7071246159044051e-01, 8.4990129293690209e-02, + -1.2558035496847255e+00, -3.1123763096053136e-02, -4.4100135935181761e-01, + 6.4707184007995455e-01, 1.5574441384822924e-01, 3.2409058144551339e-01, + 2.8631311270672963e+00, -3.0375434485037031e-04, 3.9533024424985619e-01, + 3.2722174727830535e+00, 1.1867224518409690e-01, -2.2250901443705223e-01, + 5.0337980348311300e+00, 6.0517723355290898e-01, -5.5204995585567707e-01, + -3.8335680797875722e+00, -2.3083403461022087e-01, 3.1281970616476651e-01, + -1.0733902445454071e+01, -2.7634498084191517e-01, 1.5720135955951031e+00, + -2.9262906180354680e+00, 1.0845127764896278e-01, -1.1142053272645919e-01, + 3.6066832583682209e+00, -1.9002351752094526e-01, 3.1875602887687587e-01, + 3.6971839777382898e-01, -2.7352380159430506e-02, 1.0670299036230046e-01, + 1.8155828042674422e+00, 4.9170982983933986e-01, -6.7166291183351579e-01, + -2.9003369690467395e+00, -7.6647630459927585e-01, 1.0566933380800889e+00, + -4.8620953903555858e-01, 4.0440213825136057e-01, -6.5227187264812003e-01, + -4.4421997400831864e-01, 1.4811202361724179e-01, -2.4354470120979710e-01, + 5.3346700156430571e-01, -1.8977527286286849e-01, 3.1383559345422440e-01}; int natoms; double expected_tot_e; std::vector expected_tot_v; diff --git a/source/lmp/tests/test_lammps_dpa_pt.py b/source/lmp/tests/test_lammps_dpa_pt.py index a4e2f93014..0c1d46c5f7 100644 --- a/source/lmp/tests/test_lammps_dpa_pt.py +++ b/source/lmp/tests/test_lammps_dpa_pt.py @@ -33,35 +33,35 @@ # this is as the same as python and c++ tests, test_deeppot_a.py expected_ae = np.array( [ - -93.295296030283, - -186.548183879333, - -186.988827037855, - -93.295307298571, - -186.799369383945, - -186.507754447584, + -94.37720733019096, + -187.43155959873033, + -187.37830241580824, + -94.34880710985752, + -187.38869830422271, + -187.33919952642458, ] ) expected_e = np.sum(expected_ae) expected_f = np.array( [ - 4.964133039248, - -0.542378158452, - -0.381267990914, - -0.563388054735, - 0.340320322541, - 0.473406268590, - 0.159774831398, - 0.684651816874, - -0.377008867620, - -4.718603033927, - -0.012604322920, - -0.425121993870, - -0.500302936762, - -0.637586419292, - 0.930351899011, - 0.658386154778, - 0.167596761250, - -0.220359315197, + 5.402355596838843, + -1.263284191331685, + -0.697693239979719, + -1.025144852453706, + 0.6554396369933394, + 0.8817286288078215, + 0.4364579972147229, + 1.2150079148857598, + -0.6778076371985796, + -6.939243547937094, + 0.1571084862688049, + -0.9017435514431825, + 0.3597967524845581, + -1.328808718007412, + 2.0974306454214653, + 1.7657780538526762, + 0.5645368711911929, + -0.7019148456078053, ] ).reshape(6, 3) @@ -78,60 +78,60 @@ expected_v = -np.array( [ - -5.055176133632, - -0.743392222876, - 0.330846378467, - -0.031111229868, - 0.018004461517, - 0.170047655301, - -0.063087726831, - -0.004361215202, - -0.042920299661, - 3.624188578021, - -0.252818122305, - -0.026516806138, - -0.014510755893, - 0.103726553937, - 0.181001311123, - -0.508673535094, - 0.142101134395, - 0.135339636607, - -0.460067993361, - 0.120541583338, - -0.206396390140, - -0.630991740522, - 0.397670086144, - -0.427022150075, - 0.656463775044, - -0.209989614377, - 0.288974239790, - -7.603428707029, - -0.912313971544, - 0.882084544041, - -0.807760666057, - -0.070519570327, - 0.022164414763, - 0.569448616709, - 0.028522950109, - 0.051641619288, - -1.452133900157, - 0.037653156584, - -0.144421326931, - -0.308825789350, - 0.302020522568, - -0.446073217801, - 0.313539058423, - -0.461052923736, - 0.678235442273, - 1.429780276456, - 0.080472825760, - -0.103424652500, - 0.123343430648, - 0.011879908277, - -0.018897229721, - -0.235518441452, - -0.013999547600, - 0.027007016662, + 9.5175137906314511e-01, + -2.0801835688892991e00, + 4.6860789988973117e-01, + -6.0178723966859824e00, + 1.2556002911926123e-01, + 4.7887097832213565e-02, + 5.6216590124464116e-01, + 1.7071246159044051e-01, + 8.4990129293690209e-02, + -1.2558035496847255e00, + -3.1123763096053136e-02, + -4.4100135935181761e-01, + 6.4707184007995455e-01, + 1.5574441384822924e-01, + 3.2409058144551339e-01, + 2.8631311270672963e00, + -3.0375434485037031e-04, + 3.9533024424985619e-01, + 3.2722174727830535e00, + 1.1867224518409690e-01, + -2.2250901443705223e-01, + 5.0337980348311300e00, + 6.0517723355290898e-01, + -5.5204995585567707e-01, + -3.8335680797875722e00, + -2.3083403461022087e-01, + 3.1281970616476651e-01, + -1.0733902445454071e01, + -2.7634498084191517e-01, + 1.5720135955951031e00, + -2.9262906180354680e00, + 1.0845127764896278e-01, + -1.1142053272645919e-01, + 3.6066832583682209e00, + -1.9002351752094526e-01, + 3.1875602887687587e-01, + 3.6971839777382898e-01, + -2.7352380159430506e-02, + 1.0670299036230046e-01, + 1.8155828042674422e00, + 4.9170982983933986e-01, + -6.7166291183351579e-01, + -2.9003369690467395e00, + -7.6647630459927585e-01, + 1.0566933380800889e00, + -4.8620953903555858e-01, + 4.0440213825136057e-01, + -6.5227187264812003e-01, + -4.4421997400831864e-01, + 1.4811202361724179e-01, + -2.4354470120979710e-01, + 5.3346700156430571e-01, + -1.8977527286286849e-01, + 3.1383559345422440e-01, ] ).reshape(6, 9) expected_v2 = -np.array( diff --git a/source/tests/infer/deeppot_dpa.pth b/source/tests/infer/deeppot_dpa.pth index d54a1c1779..e7bffdae33 100644 Binary files a/source/tests/infer/deeppot_dpa.pth and b/source/tests/infer/deeppot_dpa.pth differ diff --git a/source/tests/infer/deeppot_sea.pth b/source/tests/infer/deeppot_sea.pth index c830f0df9e..19222ab4df 100644 Binary files a/source/tests/infer/deeppot_sea.pth and b/source/tests/infer/deeppot_sea.pth differ diff --git a/source/tests/infer/fparam_aparam.pth b/source/tests/infer/fparam_aparam.pth index 703f7267be..65fc6ef476 100644 Binary files a/source/tests/infer/fparam_aparam.pth and b/source/tests/infer/fparam_aparam.pth differ diff --git a/source/tests/universal/common/cases/atomic_model/ener_model.py b/source/tests/universal/common/cases/atomic_model/ener_model.py index 0f1daaf87b..b1e2f75cc7 100644 --- a/source/tests/universal/common/cases/atomic_model/ener_model.py +++ b/source/tests/universal/common/cases/atomic_model/ener_model.py @@ -16,3 +16,4 @@ def setUp(self) -> None: self.expected_aparam_nall = False self.expected_model_output_type = ["energy", "mask"] self.expected_sel = [8, 12] + self.expected_has_message_passing = False diff --git a/source/tests/universal/common/cases/atomic_model/utils.py b/source/tests/universal/common/cases/atomic_model/utils.py index 3b5fc64fda..39e38cdfdd 100644 --- a/source/tests/universal/common/cases/atomic_model/utils.py +++ b/source/tests/universal/common/cases/atomic_model/utils.py @@ -31,6 +31,8 @@ class AtomicModelTestCase: """Expected output type for the model.""" expected_sel: List[int] """Expected number of neighbors.""" + expected_has_message_passing: bool + """Expected whether having message passing.""" forward_wrapper: Callable[[Any], Any] """Calss wrapper for forward method.""" diff --git a/source/tests/universal/common/cases/model/ener_model.py b/source/tests/universal/common/cases/model/ener_model.py index 35d44f9784..54fc19073f 100644 --- a/source/tests/universal/common/cases/model/ener_model.py +++ b/source/tests/universal/common/cases/model/ener_model.py @@ -16,3 +16,4 @@ def setUp(self) -> None: self.expected_aparam_nall = False self.expected_model_output_type = ["energy", "mask"] self.expected_sel = [8, 12] + self.expected_has_message_passing = False diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index d67ac8e80d..30f9da3b14 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -31,6 +31,8 @@ class ModelTestCase: """Expected output type for the model.""" expected_sel: List[int] """Expected number of neighbors.""" + expected_has_message_passing: bool + """Expected whether having message passing.""" forward_wrapper: Callable[[Any], Any] """Calss wrapper for forward method.""" @@ -82,6 +84,13 @@ def test_get_ntypes(self): for module in self.modules_to_test: self.assertEqual(module.get_ntypes(), len(self.expected_type_map)) + def test_has_message_passing(self): + """Test has_message_passing.""" + for module in self.modules_to_test: + self.assertEqual( + module.has_message_passing(), self.expected_has_message_passing + ) + def test_forward(self): """Test forward and forward_lower.""" nf = 1