Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: consistent type embedding #3617

Merged
merged 5 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions deepmd/dpmodel/utils/type_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/dpmodel/utils/type_embed.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/type_embed.py#L2

Added line #L2 was not covered by tests
List,
Optional,
)

import numpy as np

Check warning on line 7 in deepmd/dpmodel/utils/type_embed.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/type_embed.py#L7

Added line #L7 was not covered by tests

from deepmd.dpmodel.common import (

Check warning on line 9 in deepmd/dpmodel/utils/type_embed.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/type_embed.py#L9

Added line #L9 was not covered by tests
PRECISION_DICT,
NativeOP,
)
from deepmd.dpmodel.utils.network import (

Check warning on line 13 in deepmd/dpmodel/utils/type_embed.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/type_embed.py#L13

Added line #L13 was not covered by tests
EmbeddingNet,
)
from deepmd.utils.version import (

Check warning on line 16 in deepmd/dpmodel/utils/type_embed.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/type_embed.py#L16

Added line #L16 was not covered by tests
check_version_compatibility,
)


class TypeEmbedNet(NativeOP):
    r"""Type embedding network.

    The embedding of every atom type is obtained by feeding the
    ``ntypes`` x ``ntypes`` identity matrix (one one-hot row per type)
    through an :class:`EmbeddingNet`.

    Parameters
    ----------
    ntypes : int
        Number of atom types
    neuron : list[int]
        Number of neurons in each hidden layer of the embedding net
    resnet_dt
        Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
    activation_function
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision
        The precision of the embedding net parameters. Supported options are |PRECISION|
    trainable
        If the weights of embedding net are trainable.
        Stored and serialized; not otherwise used by this class.
    seed
        Random seed for initializing the network parameters.
        Stored on the instance; it is not forwarded to the embedding net
        and is not part of the serialized data.
    padding
        Concat the zero padding to the output, as the default embedding of empty type.
    """

    def __init__(
        self,
        *,
        ntypes: int,
        neuron: List[int] = [],
        resnet_dt: bool = False,
        activation_function: str = "tanh",
        precision: str = "default",
        trainable: bool = True,
        seed: Optional[int] = None,
        padding: bool = False,
    ) -> None:
        self.ntypes = ntypes
        # Store a private copy: the default ``[]`` is a single shared object,
        # and a caller-supplied list must not be aliased by this instance
        # (``serialize`` hands ``self.neuron`` out again).
        self.neuron = list(neuron)
        self.seed = seed
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.activation_function = str(activation_function)
        self.trainable = trainable
        self.padding = padding
        self.embedding_net = EmbeddingNet(
            ntypes,
            self.neuron,
            self.activation_function,
            self.resnet_dt,
            self.precision,
        )

    def call(self) -> np.ndarray:
        """Compute the type embedding network.

        Returns
        -------
        np.ndarray
            The output of the embedding net applied to the identity matrix
            of size ``ntypes``. When ``padding`` is set, one extra all-zero
            row is appended along the first axis.
        """
        one_hot = np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
        embed = self.embedding_net(one_hot)
        if self.padding:
            # (0, 1) on axis 0: nothing before, one zero row after.
            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
        return embed

    @classmethod
    def deserialize(cls, data: dict):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        TypeEmbedNet
            The deserialized model

        Raises
        ------
        KeyError
            If ``"@class"`` or ``"embedding"`` is missing from ``data``.
        AssertionError
            If ``"@class"`` is not ``"TypeEmbedNet"``.
        """
        # Work on a shallow copy so the caller's dict is not consumed.
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data_cls = data.pop("@class")
        assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

        embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
        # The remaining keys are the constructor keyword arguments.
        type_embedding_net = cls(**data)
        # Replace the freshly initialised net with the restored one.
        type_embedding_net.embedding_net = embedding_net
        return type_embedding_net

    def serialize(self) -> dict:
        """Serialize the model.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "TypeEmbedNet",
            "@version": 1,
            "ntypes": self.ntypes,
            "neuron": self.neuron,
            "resnet_dt": self.resnet_dt,
            "precision": self.precision,
            "activation_function": self.activation_function,
            "trainable": self.trainable,
            "padding": self.padding,
            "embedding": self.embedding_net.serialize(),
        }
141 changes: 134 additions & 7 deletions deepmd/pt/model/network/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

Expand All @@ -8,9 +9,15 @@
import torch.nn as nn
import torch.nn.functional as F

from deepmd.pt.model.network.mlp import (

Check warning on line 12 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L12

Added line #L12 was not covered by tests
EmbeddingNet,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.version import (

Check warning on line 18 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L18

Added line #L18 was not covered by tests
check_version_compatibility,
)

try:
from typing import (
Expand Down Expand Up @@ -553,12 +560,12 @@
def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0):
"""Construct a type embedding net."""
super().__init__()
self.embedding = nn.Embedding(
type_nums + 1,
embed_dim,
padding_idx=type_nums,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
self.embedding = TypeEmbedNetConsistent(

Check warning on line 563 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L563

Added line #L563 was not covered by tests
ntypes=type_nums,
neuron=[embed_dim],
padding=True,
activation_function="Linear",
precision="default",
)
# nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)

Expand All @@ -572,7 +579,7 @@
type_embedding:

"""
return self.embedding(atype)
return self.embedding(atype.device)[atype]

Check warning on line 582 in deepmd/pt/model/network/network.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/network/network.py#L582

Added line #L582 was not covered by tests

def share_params(self, base_class, shared_level, resume=False):
"""
Expand All @@ -591,6 +598,126 @@
raise NotImplementedError


class TypeEmbedNetConsistent(nn.Module):
    r"""Type embedding network that is consistent with other backends.

    The embedding of every atom type is obtained by feeding the
    ``ntypes`` x ``ntypes`` identity matrix (one one-hot row per type)
    through an :class:`EmbeddingNet`.

    Parameters
    ----------
    ntypes : int
        Number of atom types
    neuron : list[int]
        Number of neurons in each hidden layer of the embedding net
    resnet_dt
        Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
    activation_function
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision
        The precision of the embedding net parameters. Supported options are |PRECISION|
    trainable
        If the weights of embedding net are trainable.
    seed
        Random seed for initializing the network parameters.
        Stored on the instance; it is not forwarded to the embedding net
        and is not part of the serialized data.
    padding
        Concat the zero padding to the output, as the default embedding of empty type.
    """

    def __init__(
        self,
        *,
        ntypes: int,
        neuron: List[int] = [],
        resnet_dt: bool = False,
        activation_function: str = "tanh",
        precision: str = "default",
        trainable: bool = True,
        seed: Optional[int] = None,
        padding: bool = False,
    ):
        """Construct a type embedding net."""
        super().__init__()
        self.ntypes = ntypes
        # Store a private copy: the default ``[]`` is a single shared object,
        # and a caller-supplied list must not be aliased by this instance
        # (``serialize`` hands ``self.neuron`` out again).
        self.neuron = list(neuron)
        self.seed = seed
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.prec = env.PRECISION_DICT[self.precision]
        self.activation_function = str(activation_function)
        self.trainable = trainable
        self.padding = padding
        # no way to pass seed?
        self.embedding_net = EmbeddingNet(
            ntypes,
            self.neuron,
            self.activation_function,
            self.resnet_dt,
            self.precision,
        )
        for param in self.parameters():
            param.requires_grad = trainable

    def forward(self, device: torch.device) -> torch.Tensor:
        """Calculate type embedding network.

        Parameters
        ----------
        device : torch.device
            The device on which the one-hot input and the padding row
            are created.

        Returns
        -------
        type_embedding: torch.Tensor
            Type embedding network. When ``padding`` is set, one extra
            all-zero row is appended along the first dimension.
        """
        one_hot = torch.eye(self.ntypes, dtype=self.prec, device=device)
        embed = self.embedding_net(one_hot)
        if self.padding:
            zero_row = torch.zeros(
                1, embed.shape[1], dtype=self.prec, device=device
            )
            embed = torch.cat([embed, zero_row])
        return embed

    @classmethod
    def deserialize(cls, data: dict):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        TypeEmbedNetConsistent
            The deserialized model

        Raises
        ------
        KeyError
            If ``"@class"`` or ``"embedding"`` is missing from ``data``.
        AssertionError
            If ``"@class"`` is not ``"TypeEmbedNet"``.
        """
        # Work on a shallow copy so the caller's dict is not consumed.
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data_cls = data.pop("@class")
        assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

        embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
        # The remaining keys are the constructor keyword arguments.
        type_embedding_net = cls(**data)
        # Replace the freshly initialised net with the restored one.
        type_embedding_net.embedding_net = embedding_net
        # The constructor applied ``trainable`` to the net it created, which
        # has just been replaced; apply the flag to the restored parameters
        # so a frozen model stays frozen after a round trip.
        for param in type_embedding_net.parameters():
            param.requires_grad = type_embedding_net.trainable
        return type_embedding_net

    def serialize(self) -> dict:
        """Serialize the model.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "TypeEmbedNet",
            "@version": 1,
            "ntypes": self.ntypes,
            "neuron": self.neuron,
            "resnet_dt": self.resnet_dt,
            "precision": self.precision,
            "activation_function": self.activation_function,
            "trainable": self.trainable,
            "padding": self.padding,
            "embedding": self.embedding_net.serialize(),
        }


@torch.jit.script
def gaussian(x, mean, std: float):
pi = 3.14159
Expand Down
11 changes: 7 additions & 4 deletions deepmd/tf/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,14 @@ def dlopen_library(module: str, filename: str):
r"share_.+/idt|"
)[:-1]

# subpatterns:
# \1: weight name
# \2: layer index
TYPE_EMBEDDING_PATTERN = str(
r"type_embed_net+/matrix_\d+|"
r"type_embed_net+/bias_\d+|"
r"type_embed_net+/idt_\d+|"
)
r"type_embed_net/(matrix)_(\d+)|"
r"type_embed_net/(bias)_(\d+)|"
r"type_embed_net/(idt)_(\d+)|"
)[:-1]

ATTENTION_LAYER_PATTERN = str(
r"attention_layer_\d+/c_query/matrix|"
Expand Down
2 changes: 2 additions & 0 deletions deepmd/tf/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ def __init__(
self.typeebd = type_embedding
elif type_embedding is not None:
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**type_embedding,
padding=self.descrpt.explicit_ntypes,
)
Expand All @@ -686,6 +687,7 @@ def __init__(
default_args_dict = {i.name: i.default for i in default_args}
default_args_dict["activation_function"] = None
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**default_args_dict,
padding=True,
)
Expand Down
4 changes: 3 additions & 1 deletion deepmd/tf/model/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,13 @@
dim_descrpt=self.descrpt.get_dim_out(),
)

self.ntypes = self.descrpt.get_ntypes()

Check warning on line 149 in deepmd/tf/model/multi.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/multi.py#L149

Added line #L149 was not covered by tests
# type embedding
if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
elif type_embedding is not None:
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**type_embedding,
padding=self.descrpt.explicit_ntypes,
)
Expand All @@ -159,6 +161,7 @@
default_args_dict = {i.name: i.default for i in default_args}
default_args_dict["activation_function"] = None
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**default_args_dict,
padding=True,
)
Expand All @@ -167,7 +170,6 @@

# descriptor
self.rcut = self.descrpt.get_rcut()
self.ntypes = self.descrpt.get_ntypes()
# fitting
self.fitting_dict = fitting_dict
self.numb_fparam_dict = {
Expand Down
3 changes: 2 additions & 1 deletion deepmd/tf/model/pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@
compress=compress,
**kwargs,
)
self.ntypes = len(type_map)

Check warning on line 80 in deepmd/tf/model/pairwise_dprc.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/model/pairwise_dprc.py#L80

Added line #L80 was not covered by tests
# type embedding
if isinstance(type_embedding, TypeEmbedNet):
self.typeebd = type_embedding
else:
self.typeebd = TypeEmbedNet(
ntypes=self.ntypes,
**type_embedding,
# must use se_atten, so it must be True
padding=True,
Expand All @@ -100,7 +102,6 @@
compress=compress,
)
add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False)
self.ntypes = len(type_map)
self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())

def build(
Expand Down
Loading
Loading