Skip to content

Commit

Permalink
apply compression for se_e2_a_tebd (#2841)
Browse files Browse the repository at this point in the history
  • Loading branch information
nahso authored Sep 27, 2023
1 parent 498bfa0 commit e937345
Show file tree
Hide file tree
Showing 12 changed files with 1,758 additions and 108 deletions.
4 changes: 4 additions & 0 deletions deepmd/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .se_a_ebd import (
DescrptSeAEbd,
)
from .se_a_ebd_v2 import (
DescrptSeAEbdV2,
)
from .se_a_ef import (
DescrptSeAEf,
DescrptSeAEfLower,
Expand All @@ -39,6 +42,7 @@
"DescrptHybrid",
"DescrptLocFrame",
"DescrptSeA",
"DescrptSeAEbdV2",
"DescrptSeAEbd",
"DescrptSeAEf",
"DescrptSeAEfLower",
Expand Down
210 changes: 186 additions & 24 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepmd.common import (
cast_precision,
get_activation_func,
get_np_precision,
get_precision,
)
from deepmd.env import (
Expand All @@ -30,10 +31,17 @@
from deepmd.nvnmd.utils.config import (
nvnmd_cfg,
)
from deepmd.utils.compress import (
get_extra_side_embedding_net_variable,
get_two_side_type_embedding,
get_type_embedding,
make_data,
)
from deepmd.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.utils.graph import (
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd.utils.network import (
Expand Down Expand Up @@ -165,6 +173,7 @@ def __init__(
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
stripped_type_embedding: bool = False,
**kwargs,
) -> None:
"""Constructor."""
Expand All @@ -185,6 +194,7 @@ def __init__(
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
Expand All @@ -193,6 +203,9 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.type_one_side = type_one_side
self.spin = spin
self.stripped_type_embedding = stripped_type_embedding
self.extra_embeeding_net_variables = None
self.layer_size = len(neuron)

# extend sel_a for spin system
if self.spin is not None:
Expand Down Expand Up @@ -463,6 +476,39 @@ def enable_compression(
"The size of the next layer of the neural network must be twice the size of the previous layer."
% ",".join([str(item) for item in self.filter_neuron])
)
if self.stripped_type_embedding:
ret_two_side = get_pattern_nodes_from_graph_def(
graph_def, f"filter_type_all{suffix}/.+_two_side_ebd"
)
ret_one_side = get_pattern_nodes_from_graph_def(
graph_def, f"filter_type_all{suffix}/.+_one_side_ebd"
)
if len(ret_two_side) == 0 and len(ret_one_side) == 0:
raise RuntimeError(
"can not find variables of embedding net from graph_def, maybe it is not a compressible model."
)
elif len(ret_one_side) != 0 and len(ret_two_side) != 0:
raise RuntimeError(
"both one side and two side embedding net varaibles are detected, it is a wrong model."
)
elif len(ret_two_side) != 0:
self.final_type_embedding = get_two_side_type_embedding(self, graph)
self.matrix = get_extra_side_embedding_net_variable(
self, graph_def, "two_side", "matrix", suffix
)
self.bias = get_extra_side_embedding_net_variable(
self, graph_def, "two_side", "bias", suffix
)
self.extra_embedding = make_data(self, self.final_type_embedding)
else:
self.final_type_embedding = get_type_embedding(self, graph)
self.matrix = get_extra_side_embedding_net_variable(
self, graph_def, "one_side", "matrix", suffix
)
self.bias = get_extra_side_embedding_net_variable(
self, graph_def, "one_side", "bias", suffix
)
self.extra_embedding = make_data(self, self.final_type_embedding)

self.compress = True
self.table = DPTabulate(
Expand Down Expand Up @@ -588,6 +634,7 @@ def build(
coord = tf.reshape(coord_, [-1, natoms[1] * 3])
box = tf.reshape(box_, [-1, 9])
atype = tf.reshape(atype_, [-1, natoms[1]])
self.atype = atype

op_descriptor = (
build_op_descriptor() if nvnmd_cfg.enable else op_module.prod_env_mat_a
Expand All @@ -606,6 +653,10 @@ def build(
sel_a=self.sel_a,
sel_r=self.sel_r,
)
nlist_t = tf.reshape(self.nlist + 1, [-1])
atype_t = tf.concat([[self.ntypes], tf.reshape(self.atype, [-1])], axis=0)
self.nei_type_vec = tf.nn.embedding_lookup(atype_t, nlist_t)

# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
tf.summary.histogram("rij", self.rij)
Expand Down Expand Up @@ -692,6 +743,8 @@ def _pass_filter(
type_embedding = input_dict.get("type_embedding", None)
else:
type_embedding = None
if self.stripped_type_embedding and type_embedding is None:
raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.")
start_index = 0
inputs = tf.reshape(inputs, [-1, natoms[0], self.ndescrpt])
output = []
Expand Down Expand Up @@ -901,13 +954,89 @@ def _filter_lower(
# with (natom x nei_type_i) x 1
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
if type_embedding is not None:
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding
)
if self.compress:
raise RuntimeError(
"compression of type embedded descriptor is not supported at the moment"
if self.stripped_type_embedding:
if self.type_one_side:
extra_embedding_index = self.nei_type_vec
else:
padding_ntypes = type_embedding.shape[0]
atype_expand = tf.reshape(self.atype, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
index_of_two_side = tf.reshape(idx, [-1])
extra_embedding_index = index_of_two_side

if not self.compress:
if self.type_one_side:
one_side_type_embedding_suffix = "_one_side_ebd"
net_output = embedding_net(
type_embedding,
self.filter_neuron,
self.filter_precision,
activation_fn=activation_fn,
resnet_dt=self.filter_resnet_dt,
name_suffix=one_side_type_embedding_suffix,
stddev=stddev,
bavg=bavg,
seed=self.seed,
trainable=trainable,
uniform_seed=self.uniform_seed,
initial_variables=self.extra_embeeding_net_variables,
mixed_prec=self.mixed_prec,
)
net_output = tf.nn.embedding_lookup(
net_output, self.nei_type_vec
)
else:
type_embedding_nei = tf.tile(
tf.reshape(type_embedding, [1, padding_ntypes, -1]),
[padding_ntypes, 1, 1],
) # (ntypes) * ntypes * Y
type_embedding_center = tf.tile(
tf.reshape(type_embedding, [padding_ntypes, 1, -1]),
[1, padding_ntypes, 1],
) # ntypes * (ntypes) * Y
two_side_type_embedding = tf.concat(
[type_embedding_nei, type_embedding_center], -1
) # ntypes * ntypes * (Y+Y)
two_side_type_embedding = tf.reshape(
two_side_type_embedding,
[-1, two_side_type_embedding.shape[-1]],
)

atype_expand = tf.reshape(self.atype, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
index_of_two_side = tf.reshape(idx, [-1])
self.extra_embedding_index = index_of_two_side

two_side_type_embedding_suffix = "_two_side_ebd"
net_output = embedding_net(
two_side_type_embedding,
self.filter_neuron,
self.filter_precision,
activation_fn=activation_fn,
resnet_dt=self.filter_resnet_dt,
name_suffix=two_side_type_embedding_suffix,
stddev=stddev,
bavg=bavg,
seed=self.seed,
trainable=trainable,
uniform_seed=self.uniform_seed,
initial_variables=self.extra_embeeding_net_variables,
mixed_prec=self.mixed_prec,
)
net_output = tf.nn.embedding_lookup(net_output, idx)
net_output = tf.reshape(net_output, [-1, self.filter_neuron[-1]])
else:
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding
)
if self.compress:
raise RuntimeError(
"compression of type embedded descriptor is not supported when stripped_type_embedding == False"
)
# natom x 4 x outputs_size
if nvnmd_cfg.enable:
return filter_lower_R42GR(
Expand All @@ -929,25 +1058,48 @@ def _filter_lower(
self.embedding_net_variables,
)
if self.compress and (not is_exclude):
if self.type_one_side:
net = "filter_-1_net_" + str(type_i)
if self.stripped_type_embedding:
net_output = tf.nn.embedding_lookup(
self.extra_embedding, extra_embedding_index
)
net = "filter_net"
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_atten(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
net_output,
last_layer_size=outputs_size[-1],
is_sorted=False,
)
else:
net = "filter_" + str(type_input) + "_net_" + str(type_i)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_a(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
last_layer_size=outputs_size[-1],
)
if self.type_one_side:
net = "filter_-1_net_" + str(type_i)
else:
net = "filter_" + str(type_input) + "_net_" + str(type_i)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_a(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
last_layer_size=outputs_size[-1],
)
else:
if not is_exclude:
# with (natom x nei_type_i) x out_size
Expand All @@ -966,6 +1118,9 @@ def _filter_lower(
initial_variables=self.embedding_net_variables,
mixed_prec=self.mixed_prec,
)

if self.stripped_type_embedding:
xyz_scatter = xyz_scatter * net_output + xyz_scatter
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
else:
Expand Down Expand Up @@ -1179,3 +1334,10 @@ def init_variables(
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel

@property
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
if self.stripped_type_embedding:
return True
return False
70 changes: 70 additions & 0 deletions deepmd/descriptor/se_a_ebd_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
)

from deepmd.utils.spin import (
Spin,
)

from .descriptor import (
Descriptor,
)
from .se_a import (
DescrptSeA,
)

log = logging.getLogger(__name__)


@Descriptor.register("se_a_tpe_v2")
@Descriptor.register("se_a_ebd_v2")
class DescrptSeAEbdV2(DescrptSeA):
r"""A compressible se_a_ebd model.
This model is a warpper for DescriptorSeA, which set stripped_type_embedding=True.
"""

def __init__(
self,
rcut: float,
rcut_smth: float,
sel: List[str],
neuron: List[int] = [24, 48, 96],
axis_neuron: int = 8,
resnet_dt: bool = False,
trainable: bool = True,
seed: Optional[int] = None,
type_one_side: bool = True,
exclude_types: List[List[int]] = [],
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
**kwargs,
) -> None:
DescrptSeA.__init__(
self,
rcut,
rcut_smth,
sel,
neuron=neuron,
axis_neuron=axis_neuron,
resnet_dt=resnet_dt,
trainable=trainable,
seed=seed,
type_one_side=type_one_side,
exclude_types=exclude_types,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
uniform_seed=uniform_seed,
multi_task=multi_task,
spin=spin,
stripped_type_embedding=True,
**kwargs,
)
Loading

0 comments on commit e937345

Please sign in to comment.