Skip to content

Commit

Permalink
fix message passing api
Browse files Browse the repository at this point in the history
  • Loading branch information
CaRoLZhangxy committed Jun 3, 2024
1 parent f23c77e commit 5d9b94c
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 11 deletions.
6 changes: 3 additions & 3 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,10 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
if model.has_message_passing() == True:
extra_files = {"has_message_passing": "True"}
else:
extra_files = {"type": "else"}
extra_files = {"has_message_passing": "False"}
torch.jit.save(
model,
FLAGS.output,
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""
If true, the model has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
Expand Down
12 changes: 10 additions & 2 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,17 +331,25 @@ def get_dim_emb(self) -> int:
return self.se_atten.dim_emb

def mixed_types(self) -> bool:
"""If true, the discriptor
"""If true, the descriptor
1. assumes total number of atoms aligned across frames;
2. requires a neighbor list that does not distinguish different atomic types.
If false, the discriptor
If false, the descriptor
1. assumes total number of atoms of each atom type aligned across frames;
2. requires a neighbor list that distinguishes different atomic types.
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
12 changes: 10 additions & 2 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,25 @@ def get_dim_emb(self) -> int:
return self.repformers.dim_emb

def mixed_types(self) -> bool:
"""If true, the discriptor
"""If true, the descriptor
1. assumes total number of atoms aligned across frames;
2. requires a neighbor list that does not distinguish different atomic types.
If false, the discriptor
If false, the descriptor
1. assumes total number of atoms of each atom type aligned across frames;
2. requires a neighbor list that distinguishes different atomic types.
"""
return True

def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return True

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,11 @@ def mixed_types(self):
atomic types or not.
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""
Returns if the descriptor requires message passing or not.
"""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,13 @@ def mixed_types(self) -> bool:
"""
return True
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return True
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,13 @@ def mixed_types(self):
atomic types or not.
"""
return self.sea.mixed_types()
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -478,7 +484,13 @@ def mixed_types(self) -> bool:
"""
return False
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,14 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,13 @@ def mixed_types(self) -> bool:
"""
return False
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,13 @@ def mixed_types(self):
atomic types or not.
"""
return self.seat.mixed_types()
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down Expand Up @@ -503,7 +509,13 @@ def mixed_types(self) -> bool:
"""
return False
def has_message_passing(self) -> bool:
"""
If true, the descriptor has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return False
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,13 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()
def has_message_passing(self) -> bool:
"""
If true, the model has a structure of message-passing network, which requires communication op in multi-process inference.
If false, the op above is not needed.
"""
return self.atomic_model.has_message_passing()

def forward(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,7 @@ def get_min_nbor_dist(self) -> Optional[float]:
def get_ntypes(self):
"""Returns the number of element types."""
return len(self.get_type_map())
@torch.jit.export
def has_message_passing(self):
"""Returns the message_passing atrribute of model."""
return self.has_message_passing()
4 changes: 1 addition & 3 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ void DeepPotPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
// TODO: This should be fixed after implement api to decide whether need to
// message passing and rename this metadata
if (metadata["type"] == "dpa2") {
if (metadata["has_message_passing"] == "True") {
do_message_passing = 1;
} else {
do_message_passing = 0;
Expand Down

0 comments on commit 5d9b94c

Please sign in to comment.