Skip to content

Commit

Permalink
Add se_attn_v2 in dp
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed May 30, 2024
1 parent 2e8396e commit 87c3561
Showing 1 changed file with 58 additions and 6 deletions.
64 changes: 58 additions & 6 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def np_normalize(x, axis=-1):
return x / np.linalg.norm(x, axis=axis, keepdims=True)


@BaseDescriptor.register("se_atten")
@BaseDescriptor.register("se_atten_v2")
@BaseDescriptor.register("dpa1")
class DescrptDPA1(NativeOP, BaseDescriptor):
r"""Attention-based descriptor which is proposed in the pretrainable DPA-1[1] model.
Expand Down Expand Up @@ -226,7 +226,6 @@ def __init__(
neuron: List[int] = [25, 50, 100],
axis_neuron: int = 8,
tebd_dim: int = 8,
tebd_input_mode: str = "concat",
resnet_dt: bool = False,
trainable: bool = True,
type_one_side: bool = False,
Expand All @@ -244,7 +243,6 @@ def __init__(
temperature: Optional[float] = None,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
smooth_type_embedding: bool = True,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
Expand All @@ -268,15 +266,14 @@ def __init__(
if ln_eps is None:
ln_eps = 1e-5

self.se_atten = DescrptBlockSeAtten(
self.se_atten = DescrptBlockSeAttenV2(
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode=tebd_input_mode,
set_davg_zero=set_davg_zero,
attn=attn,
attn_layer=attn_layer,
Expand All @@ -288,7 +285,6 @@ def __init__(
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
smooth=smooth_type_embedding,
type_one_side=type_one_side,
exclude_types=exclude_types,
env_protection=env_protection,
Expand Down Expand Up @@ -870,6 +866,62 @@ def call(
sw,
)

@DescriptorBlock.register("se_atten_v2")
class DescrptBlockSeAttenV2(DescrptBlockSeAtten):
def __init__(
self,
rcut: float,
rcut_smth: float,
sel: Union[List[int], int],
ntypes: int,
neuron: List[int] = [25, 50, 100],
axis_neuron: int = 8,
tebd_dim: int = 8,
resnet_dt: bool = False,
type_one_side: bool = False,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
scaling_factor=1.0,
normalize: bool = True,
temperature: Optional[float] = None,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
) -> None:
DescrptBlockSeAtten.__init__(
self,
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode="strip",
resnet_dt=resnet_dt,
type_one_side=type_one_side,
attn=attn,
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
exclude_types=exclude_types,
env_protection=env_protection,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth=True,
)

class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down

0 comments on commit 87c3561

Please sign in to comment.