From 5d9b94cf5c422e592fc1b1dc572a1cf06f11e9f1 Mon Sep 17 00:00:00 2001 From: Lysithea Date: Mon, 3 Jun 2024 15:52:43 +0800 Subject: [PATCH] fix message passing api --- deepmd/pt/entrypoints/main.py | 6 +++--- deepmd/pt/model/atomic_model/dp_atomic_model.py | 8 ++++++++ deepmd/pt/model/descriptor/dpa1.py | 12 ++++++++++-- deepmd/pt/model/descriptor/dpa2.py | 12 ++++++++++-- deepmd/pt/model/descriptor/hybrid.py | 6 +++++- deepmd/pt/model/descriptor/repformers.py | 6 ++++++ deepmd/pt/model/descriptor/se_a.py | 12 ++++++++++++ deepmd/pt/model/descriptor/se_atten.py | 8 ++++++++ deepmd/pt/model/descriptor/se_r.py | 6 ++++++ deepmd/pt/model/descriptor/se_t.py | 12 ++++++++++++ deepmd/pt/model/model/make_model.py | 7 +++++++ deepmd/pt/model/model/model.py | 4 ++++ source/api_cc/src/DeepPotPT.cc | 4 +--- 13 files changed, 92 insertions(+), 11 deletions(-) diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index ab9992b6b5..e62e8626ef 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -303,10 +303,10 @@ def train(FLAGS): def freeze(FLAGS): model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model) - if '"type": "dpa2"' in model.get_model_def_script(): - extra_files = {"type": "dpa2"} + if model.has_message_passing() == True: + extra_files = {"has_message_passing": "True"} else: - extra_files = {"type": "else"} + extra_files = {"has_message_passing": "False"} torch.jit.save( model, FLAGS.output, diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 3d9a57bf70..e7af7b7715 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -94,6 +94,14 @@ def mixed_types(self) -> bool: """ return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + """ + If true, the model has a structure of message-passing network, which requires communication op in multi-process inference. + + If false, the op above is not needed. + """ + return self.descriptor.has_message_passing() def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 92f5cf2e15..d5aed8ebfe 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -331,17 +331,25 @@ def get_dim_emb(self) -> int: return self.se_atten.dim_emb def mixed_types(self) -> bool: - """If true, the discriptor + """If true, the descriptor 1. assumes total number of atoms aligned across frames; 2. requires a neighbor list that does not distinguish different atomic types. - If false, the discriptor + If false, the descriptor 1. assumes total number of atoms of each atom type aligned across frames; 2. requires a neighbor list that distinguishes different atomic types. """ return self.se_atten.mixed_types() + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + + If false, the op above is not needed. + """ + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.se_atten.get_env_protection() diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index b33a528721..829434c795 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -283,17 +283,25 @@ def get_dim_emb(self) -> int: return self.repformers.dim_emb def mixed_types(self) -> bool: - """If true, the discriptor + """If true, the descriptor 1. assumes total number of atoms aligned across frames; 2. requires a neighbor list that does not distinguish different atomic types. - If false, the discriptor + If false, the descriptor 1. assumes total number of atoms of each atom type aligned across frames; 2. requires a neighbor list that distinguishes different atomic types. """ return True + + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return True + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" # the env_protection of repinit is the same as that of the repformer diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index e202005f4c..1ee2aed5b2 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -142,7 +142,11 @@ def mixed_types(self): atomic types or not. """ return any(descrpt.mixed_types() for descrpt in self.descrpt_list) - + def has_message_passing(self) -> bool: + """ + Returns if the descriptor requires message passing or not. + """ + return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) def get_env_protection(self) -> float: """Returns the protection of building environment matrix. All descriptors should be the same.""" all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list] diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index a66693653e..02755eb5bc 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -366,7 +366,13 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return True def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 350fceae2d..fdaecb70db 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -146,7 +146,13 @@ def mixed_types(self): atomic types or not. """ return self.sea.mixed_types() + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return False def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.sea.get_env_protection() @@ -478,7 +484,13 @@ def mixed_types(self) -> bool: """ return False + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return False def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index a59eaca409..0627c60ea3 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -349,6 +349,14 @@ def mixed_types(self) -> bool: """ return True + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + + If false, the op above is not needed. + """ + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index b0a739f5e6..d887442f77 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -169,7 +169,13 @@ def mixed_types(self) -> bool: """ return False + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return False def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 2c8f52709f..46520d9e89 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -173,7 +173,13 @@ def mixed_types(self): atomic types or not. """ return self.seat.mixed_types() + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return False def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.seat.get_env_protection() @@ -503,7 +509,13 @@ def mixed_types(self) -> bool: """ return False + def has_message_passing(self) -> bool: + """ + If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference. + If false, the op above is not needed. + """ + return False def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 989789c201..000b6019fa 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -530,6 +530,13 @@ def mixed_types(self) -> bool: """ return self.atomic_model.mixed_types() + def has_message_passing(self) -> bool: + """ + If true, the model has a structure of message-passing network, which requires communication op in multi-process inference. + + If false, the op above is not needed. + """ + return self.atomic_model.has_message_passing() def forward( self, diff --git a/deepmd/pt/model/model/model.py b/deepmd/pt/model/model/model.py index d3670737ba..f927408e14 100644 --- a/deepmd/pt/model/model/model.py +++ b/deepmd/pt/model/model/model.py @@ -56,3 +56,7 @@ def get_min_nbor_dist(self) -> Optional[float]: def get_ntypes(self): """Returns the number of element types.""" return len(self.get_type_map()) + @torch.jit.export + def has_message_passing(self): + """Returns the message_passing atrribute of model.""" + return self.has_message_passing() diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 232c9efd31..48800c4507 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -59,9 +59,7 @@ void DeepPotPT::init(const std::string& model, } std::unordered_map metadata = {{"type", ""}}; module = torch::jit::load(model, device, metadata); - // TODO: This should be fixed after implement api to decide whether need to - // message passing and rename this metadata - if (metadata["type"] == "dpa2") { + if (metadata["has_message_passing"] == "True") { do_message_passing = 1; } else { do_message_passing = 0;