Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add se_atten_v2to PyTorch and DP #3840

Merged
merged 46 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
2e8396e
Add se_attn_v2 in pt
Chengqian-Zhang May 30, 2024
87c3561
Add se_attn_v2 in dp
Chengqian-Zhang May 30, 2024
9b238ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
95696ee
Delete redirect to dpa1
Chengqian-Zhang May 30, 2024
e55a891
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang May 30, 2024
08bf55a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
999adc8
Change se_attn_v2 impl in dp and pt
Chengqian-Zhang May 30, 2024
0e27dc8
Fix conflict
Chengqian-Zhang May 30, 2024
c08f61d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
59269f2
Add UT
Chengqian-Zhang May 30, 2024
30cc6af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 30, 2024
46783e3
Fix bug
Chengqian-Zhang May 30, 2024
fdc67f8
Merge branch 'devel' into 3831
Chengqian-Zhang May 31, 2024
c46fb27
Merge branch 'devel' into 3831
Chengqian-Zhang Jun 3, 2024
9ae4dc7
Add UT for consistency
Chengqian-Zhang Jun 3, 2024
7e4826e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
d6a6e9b
change sentence
Chengqian-Zhang Jun 3, 2024
9680443
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
d4696a1
Delete tf UT
Chengqian-Zhang Jun 3, 2024
d5c7ace
Solve conflict
Chengqian-Zhang Jun 3, 2024
67ad350
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
8494c44
Change Doc of se_atten_v2
Chengqian-Zhang Jun 3, 2024
0218a01
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 3, 2024
2e990cd
Delete tf se_atten_v2
Chengqian-Zhang Jun 3, 2024
890df1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
98d6daf
change attn-->atten
Chengqian-Zhang Jun 3, 2024
3d66922
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 3, 2024
a14fed2
Add serialize in se_attn_v2 tf
Chengqian-Zhang Jun 3, 2024
c642823
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
10c8909
fix pre-commit
Chengqian-Zhang Jun 3, 2024
6b9a93e
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 3, 2024
6177522
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
6813f7f
Solve conversation
Chengqian-Zhang Jun 4, 2024
45f9755
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 4, 2024
d2fd8a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2024
9d4d0e3
Solve alert
Chengqian-Zhang Jun 4, 2024
d25993f
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 4, 2024
001d21c
Solve alert
Chengqian-Zhang Jun 4, 2024
8ee1125
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 4, 2024
b0380f9
Update deepmd/dpmodel/descriptor/se_atten_v2.py
Chengqian-Zhang Jun 4, 2024
6f2473f
Update deepmd/pt/model/descriptor/se_atten_v2.py
Chengqian-Zhang Jun 4, 2024
ef94c04
Update deepmd/pt/model/descriptor/se_atten_v2.py
Chengqian-Zhang Jun 4, 2024
b39fcce
Update deepmd/utils/argcheck.py
Chengqian-Zhang Jun 4, 2024
c982d0e
Merge branch 'devel' into 3831
Chengqian-Zhang Jun 5, 2024
c3892f4
Delete doc_stripped_type_embedding
Chengqian-Zhang Jun 5, 2024
38fc569
Merge branch '3831' of github.com:Chengqian-Zhang/deepmd-kit into 3831
Chengqian-Zhang Jun 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dpa1 import (
DescrptDPA1,
DescrptSeAttenV2,
)
from .dpa2 import (
DescrptDPA2,
Expand All @@ -22,6 +23,7 @@
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
Expand Down
148 changes: 148 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,154 @@
)


@BaseDescriptor.register("se_atten_v2")
Chengqian-Zhang marked this conversation as resolved.
Show resolved Hide resolved
class DescrptSeAttenV2(DescrptDPA1):
def __init__(
self,
rcut: float,
rcut_smth: float,
sel: Union[List[int], int],
ntypes: int,
neuron: List[int] = [25, 50, 100],
axis_neuron: int = 8,
tebd_dim: int = 8,
resnet_dt: bool = False,
trainable: bool = True,
type_one_side: bool = False,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
scaling_factor=1.0,
normalize: bool = True,
temperature: Optional[float] = None,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
use_econf_tebd: bool = False,
type_map: Optional[List[str]] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
DescrptDPA1.__init__(
self,
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode="strip",
resnet_dt=resnet_dt,
trainable=trainable,
type_one_side=type_one_side,
attn=attn,
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
exclude_types=exclude_types,
env_protection=env_protection,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth_type_embedding=True,
concat_output_tebd=concat_output_tebd,
spin=spin,
stripped_type_embedding=stripped_type_embedding,
use_econf_tebd=use_econf_tebd,
type_map=type_map,
# consistent with argcheck, not used though
seed=seed,
)

def serialize(self) -> dict:
Chengqian-Zhang marked this conversation as resolved.
Show resolved Hide resolved
"""Serialize the descriptor to dict."""
obj = self.se_atten
data = {

Check warning on line 950 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L949-L950

Added lines #L949 - L950 were not covered by tests
"@class": "Descriptor",
"type": "se_attn_v2",
"@version": 1,
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
"ntypes": obj.ntypes,
"neuron": obj.neuron,
"axis_neuron": obj.axis_neuron,
"tebd_dim": obj.tebd_dim,
"set_davg_zero": obj.set_davg_zero,
"attn": obj.attn,
"attn_layer": obj.attn_layer,
"attn_dotr": obj.attn_dotr,
"attn_mask": False,
"activation_function": obj.activation_function,
"resnet_dt": obj.resnet_dt,
"scaling_factor": obj.scaling_factor,
"normalize": obj.normalize,
"temperature": obj.temperature,
"trainable_ln": obj.trainable_ln,
"ln_eps": obj.ln_eps,
"type_one_side": obj.type_one_side,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"type_map": self.type_map,
# make deterministic
"precision": np.dtype(PRECISION_DICT[obj.precision]).name,
"embeddings": obj.embeddings.serialize(),
"embeddings_strip": obj.embeddings_strip.serialize(),
"attention_layers": obj.dpa1_attention.serialize(),
"env_mat": obj.env_mat.serialize(),
"type_embedding": self.type_embedding.serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
},
## to be updated when the options are supported.
"trainable": self.trainable,
"spin": None,
}
return data

Check warning on line 994 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L994

Added line #L994 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
"""Deserialize from dict."""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
Fixed Show fixed Hide fixed
Chengqian-Zhang marked this conversation as resolved.
Show resolved Hide resolved
embeddings_strip = data.pop("embeddings_strip")
obj = cls(**data)

obj.se_atten["davg"] = variables["davg"]
obj.se_atten["dstd"] = variables["dstd"]
obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings)
obj.se_atten.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
return obj


class NeighborGatedAttention(NativeOP):
def __init__(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
DescrptBlockSeA,
DescrptSeA,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
from .se_r import (
DescrptSeR,
)
Expand All @@ -39,6 +42,7 @@
"make_default_type_embedding",
"DescrptBlockSeA",
"DescrptBlockSeAtten",
"DescrptSeAttenV2",
"DescrptSeA",
"DescrptSeR",
"DescrptDPA1",
Expand Down
Loading