Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add has_message_passing API #3851

Merged
merged 6 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

Check warning on line 97 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L97

Added line #L97 was not covered by tests

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

Check warning on line 136 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L136

Added line #L136 was not covered by tests

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ def call(
):
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

Check warning on line 354 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L354

Added line #L354 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -885,6 +889,10 @@
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

Check warning on line 895 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L895

Added line #L895 was not covered by tests

class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,12 @@
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(

Check warning on line 529 in deepmd/dpmodel/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa2.py#L529

Added line #L529 was not covered by tests
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

Check warning on line 143 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L143

Added line #L143 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

Check warning on line 392 in deepmd/dpmodel/descriptor/repformers.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/repformers.py#L392

Added line #L392 was not covered by tests

njzjz marked this conversation as resolved.
Show resolved Hide resolved

# translated by GPT and modified
def get_residual(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 191 in deepmd/dpmodel/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_t.py#L191

Added line #L191 was not covered by tests
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def forward(
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""


def make_default_type_embedding(
ntypes,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,7 @@ def get_stats(self) -> Dict[str, StatItem]:
"The statistics of the descriptor has not been computed."
)
return self.stats

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ def mixed_types(self):
"""
return self.sea.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.sea.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -674,3 +678,7 @@ def forward(
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,10 @@ def forward(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

Check warning on line 175 in deepmd/pt/model/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_r.py#L175

Added line #L175 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@
"""
return self.seat.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.seat.has_message_passing()

Check warning on line 179 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L179

Added line #L179 was not covered by tests

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.seat.get_env_protection()
Expand Down Expand Up @@ -687,3 +691,7 @@
None,
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

Check warning on line 697 in deepmd/pt/model/descriptor/se_t.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/se_t.py#L697

Added line #L697 was not covered by tests
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@
"""
return self.model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.model.has_message_passing()

Check warning on line 112 in deepmd/pt/model/model/frozen.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/frozen.py#L112

Added line #L112 was not covered by tests

@torch.jit.export
def forward(
self,
Expand Down
5 changes: 5 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,11 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

@torch.jit.export
def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def forward(
self,
coord,
Expand Down
8 changes: 1 addition & 7 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,7 @@ void DeepPotPT::init(const std::string& model,
}
std::unordered_map<std::string, std::string> metadata = {{"type", ""}};
module = torch::jit::load(model, device, metadata);
// TODO: This should be fixed after implement api to decide whether need to
// message passing and rename this metadata
if (metadata["type"] == "dpa2") {
do_message_passing = 1;
} else {
do_message_passing = 0;
}
do_message_passing = module.run_method("has_message_passing").toBool();
torch::jit::FusionStrategy strategy;
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 10}};
torch::jit::setFusionStrategy(strategy);
Expand Down
Loading
Loading