Add 4 pt descriptor compression (#4227)
se_a, se_atten(DPA1), se_t, se_r

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced a model compression feature across multiple descriptor classes, enhancing performance and efficiency.
  - Added `enable_compression` methods to various descriptor classes, allowing users to enable and configure compression settings (see the usage sketch below).

- **Bug Fixes**
  - Improved error handling for unsupported compression scenarios and parameter validation.

- **Tests**
  - Added comprehensive unit tests for the new compression functionality across multiple descriptor classes to ensure accuracy and reliability.

- **Documentation**
  - Enhanced documentation for the new methods and classes to clarify usage and parameters related to compression.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
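
A rough usage sketch of the new API (not part of this commit): the import path, class name, and hyperparameter values below are assumptions for illustration, and actually evaluating a compressed descriptor still requires the customized `tabulate_fusion` PyTorch OP library to be built.

```python
# Minimal sketch: enabling compression on a se_a descriptor (values are placeholders).
from deepmd.pt.model.descriptor.se_a import DescrptSeA  # class name assumed from the "se_a" alias

descriptor = DescrptSeA(
    rcut=6.0,          # cutoff radius
    rcut_smth=0.5,     # where the smoothing switch starts
    sel=[46, 92],      # max neighbors per atom type
    neuron=[25, 50, 100],
)

# min_nbor_dist is normally the nearest interatomic distance in the training data.
descriptor.enable_compression(
    min_nbor_dist=0.9,
    table_extrapolate=5,   # scale of model extrapolation
    table_stride_1=0.01,   # uniform stride of the first table
    table_stride_2=0.1,    # uniform stride of the second table
    check_frequency=-1,    # overflow check frequency
)
```

In the full workflow these values would come from the training-data statistics and the compression settings rather than being hard-coded as above.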

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Yan Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
3 people authored Nov 1, 2024
1 parent eb2832b commit 8355947
Showing 15 changed files with 2,377 additions and 388 deletions.
25 changes: 25 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
@@ -147,6 +147,31 @@ def compute_input_stats(
"""Update mean and stddev for descriptor elements."""
raise NotImplementedError

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
raise NotImplementedError("This descriptor doesn't support compression!")

@abstractmethod
def fwd(
self,
87 changes: 87 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -24,9 +24,15 @@
from deepmd.pt.utils.env import (
RESERVED_PRECISON_DICT,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
@@ -261,6 +267,8 @@ def __init__(
if ln_eps is None:
ln_eps = 1e-5

self.tebd_input_mode = tebd_input_mode

del type, spin, attn_mask
self.se_atten = DescrptBlockSeAtten(
rcut,
@@ -293,6 +301,7 @@ def __init__(
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
self.compress = False
self.type_embedding = TypeEmbedNet(
ntypes,
tebd_dim,
@@ -551,6 +560,84 @@ def t_cvt(xx):
)
return obj

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
# do some checks before the model compression process
if self.compress:
raise ValueError("Compression is already enabled.")
assert (
not self.se_atten.resnet_dt
), "Model compression error: descriptor resnet_dt must be false!"
for tt in self.se_atten.exclude_types:
if (tt[0] not in range(self.se_atten.ntypes)) or (
tt[1] not in range(self.se_atten.ntypes)
):
raise RuntimeError(
"exclude types"
+ str(tt)
+ " must within the number of atomic types "
+ str(self.se_atten.ntypes)
+ "!"
)
if (
self.se_atten.ntypes * self.se_atten.ntypes
- len(self.se_atten.exclude_types)
== 0
):
raise RuntimeError(
"Empty embedding-nets are not supported in model compression!"
)

if self.se_atten.attn_layer != 0:
raise RuntimeError("Cannot compress model when attention layer is not 0.")

if self.tebd_input_mode != "strip":
raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")

data = self.serialize()
self.table = DPTabulate(
self,
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)
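# table_config packs [table_extrapolate, table_stride_1, table_stride_2, check_frequency];
# lower/upper map each embedding-network name to its tabulated input range.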

self.se_atten.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True

def forward(
self,
extended_coord: torch.Tensor,
134 changes: 130 additions & 4 deletions deepmd/pt/model/descriptor/se_a.py
@@ -58,11 +58,34 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.tabulate import (
DPTabulate,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)

from .base_descriptor import (
BaseDescriptor,
)

if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):

def tabulate_fusion_se_a(
argument0,
argument1,
argument2,
argument3,
argument4,
) -> list[torch.Tensor]:
raise NotImplementedError(
"tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
"See documentation for model compression for details."
)

# Note: this hack cannot actually save a model that can be run using LAMMPS.
torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a
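# The stub only keeps freezing/scripting from failing when the customized OP
# library is absent; calling it still raises NotImplementedError at run time.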


@BaseDescriptor.register("se_e2_a")
@BaseDescriptor.register("se_a")
@@ -93,6 +116,7 @@ def __init__(
raise NotImplementedError("old implementation of spin is not supported.")
super().__init__()
self.type_map = type_map
self.compress = False
self.sea = DescrptBlockSeA(
rcut,
rcut_smth,
@@ -225,6 +249,53 @@ def reinit_exclude(
"""Update the type exclusions."""
self.sea.reinit_exclude(exclude_types)

def enable_compression(
self,
min_nbor_dist: float,
table_extrapolate: float = 5,
table_stride_1: float = 0.01,
table_stride_2: float = 0.1,
check_frequency: int = -1,
) -> None:
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
Parameters
----------
min_nbor_dist
The nearest distance between atoms
table_extrapolate
The scale of model extrapolation
table_stride_1
The uniform stride of the first table
table_stride_2
The uniform stride of the second table
check_frequency
The overflow check frequency
"""
if self.compress:
raise ValueError("Compression is already enabled.")
data = self.serialize()
self.table = DPTabulate(
self,
data["neuron"],
data["type_one_side"],
data["exclude_types"],
ActivationFn(data["activation_function"]),
)
self.table_config = [
table_extrapolate,
table_stride_1,
table_stride_2,
check_frequency,
]
self.lower, self.upper = self.table.build(
min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
)
self.sea.enable_compression(
self.table.data, self.table_config, self.lower, self.upper
)
self.compress = True

def forward(
self,
coord_ext: torch.Tensor,
@@ -366,6 +437,10 @@ def update_sel(
class DescrptBlockSeA(DescriptorBlock):
ndescrpt: Final[int]
__constants__: ClassVar[list] = ["ndescrpt"]
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]
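# Class-level annotations so that TorchScript can type these compression attributes.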

def __init__(
self,
@@ -425,6 +500,13 @@ def __init__(
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)

# add for compression
self.compress = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
@@ -443,6 +525,7 @@ def __init__(
self.filter_layers = filter_layers
self.stats = None
# set trainable
self.trainable = trainable
for param in self.parameters():
param.requires_grad = trainable

@@ -470,6 +553,10 @@ def get_dim_out(self) -> int:
"""Returns the output dimension."""
return self.dim_out

def get_dim_rot_mat_1(self) -> int:
"""Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3."""
return self.filter_neuron[-1]

def get_dim_emb(self) -> int:
"""Returns the output dimension."""
return self.neuron[-1]
@@ -578,6 +665,19 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

def enable_compression(
self,
table_data,
table_config,
lower,
upper,
) -> None:
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper
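# Only stores the tables and ranges here; forward() switches to the fused
# tabulate_fusion_se_a OP once self.compress is set.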

def forward(
self,
nlist: torch.Tensor,
@@ -627,6 +727,7 @@ def forward(
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
ti = -1
# torch.jit is not happy with slice(None)
# ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
# applying a mask seems to cause performance degradation
@@ -648,10 +749,35 @@
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)

if self.compress:
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
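# info = [lower, upper, extrapolated upper bound, stride_1, stride_2, check_frequency]
# in the layout expected by the fused tabulation OP.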
ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf
tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)
gr = torch.ops.deepmd.tabulate_fusion_se_a(
tensor_data.contiguous(),
torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
ss.contiguous(),
rr.contiguous(),
self.filter_neuron[-1],
)[0]
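# The fused table lookup returns gr directly, in the same nfnl x 4 x ng layout
# as the matmul in the uncompressed branch below.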
else:
# nfnl x nt x ng
gg = ll.forward(ss)
# nfnl x 4 x ng
gr = torch.matmul(rr.permute(0, 2, 1), gg)

if ti_mask is not None:
xyz_scatter[ti_mask] += gr
else: