Skip to content

Commit

Permalink
feat: add has_message_passing API (deepmodeling#3851)
Browse files Browse the repository at this point in the history
Fix deepmodeling#3713. Fix deepmodeling#3733. This is a breaking change.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a new method `has_message_passing` across various
descriptor classes. This method returns a boolean indicating whether the
descriptor has message passing capability.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 0b962b3 commit 7432a80
Show file tree
Hide file tree
Showing 39 changed files with 266 additions and 111 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ def call(
):
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -886,6 +890,10 @@ def call(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ def call(
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True


# translated by GPT and modified
def get_residual(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def forward(
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""


def make_default_type_embedding(
ntypes,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,7 @@ def get_stats(self) -> Dict[str, StatItem]:
"The statistics of the descriptor has not been computed."
)
return self.stats

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def mixed_types(self):
"""
return self.sea.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.sea.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -674,3 +678,7 @@ def forward(
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,10 @@ def forward(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ def mixed_types(self) -> bool:
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def mixed_types(self):
"""
return self.seat.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.seat.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down Expand Up @@ -687,3 +691,7 @@ def forward(
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def mixed_types(self) -> bool:
"""
return self.model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.model.has_message_passing()

@torch.jit.export
def forward(
self,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,11 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def forward(
self,
coord,
Expand Down
8 changes: 1 addition & 7 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,7 @@ void DeepPotPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
// TODO: This should be fixed after implement api to decide whether need to
// message passing and rename this metadata
if (metadata["type"] == "dpa2") {
do_message_passing = 1;
} else {
do_message_passing = 0;
}
do_message_passing = module.run_method("has_message_passing").toBool();
torch::jit::FusionStrategy strategy;
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};
torch::jit::setFusionStrategy(strategy);
Expand Down
Loading

0 comments on commit 7432a80

Please sign in to comment.