Skip to content

Commit

Permalink
add split compress
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 31, 2024
1 parent ef48cca commit 22c3f87
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 17 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
update_residual: float = 0.1,
update_residual_init: str = "const",
skip_stat: bool = False,
a_compress_use_split: bool = False,
) -> None:
r"""The constructor for the RepFlowArgs class which defines the parameters of the repflow block in DPA3 descriptor.
Expand Down Expand Up @@ -95,6 +96,7 @@ def __init__(
self.a_mess_has_n = a_mess_has_n
self.a_use_e_mess = a_use_e_mess
self.a_compress_e_rate = a_compress_e_rate
self.a_compress_use_split = a_compress_use_split

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def init_subclass_params(sub_data, sub_class):
a_compress_e_rate=self.repflow_args.a_compress_e_rate,
a_mess_has_n=self.repflow_args.a_mess_has_n,
a_use_e_mess=self.repflow_args.a_use_e_mess,
a_compress_use_split=self.repflow_args.a_compress_use_split,
n_multi_edge_message=self.repflow_args.n_multi_edge_message,
axis_neuron=self.repflow_args.axis_neuron,
update_angle=self.repflow_args.update_angle,
Expand Down
55 changes: 38 additions & 17 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
a_compress_rate: int = 0,
a_mess_has_n: bool = True,
a_use_e_mess: bool = False,
a_compress_use_split: bool = False,
a_compress_e_rate: int = 1,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
self.update_residual_init = update_residual_init
self.a_mess_has_n = a_mess_has_n
self.a_compress_e_rate = a_compress_e_rate
self.a_compress_use_split = a_compress_use_split
self.precision = precision
self.seed = seed
self.prec = PRECISION_DICT[precision]
Expand Down Expand Up @@ -191,6 +193,8 @@ def __init__(
self.angle_dim += 2 * self.e_dim
self.a_compress_n_linear = None
self.a_compress_e_linear = None
self.e_a_compress_dim = 0
self.n_a_compress_dim = 0
else:
# angle + node/c + edge/2c * 2
# node : node/c or 0
Expand All @@ -201,20 +205,28 @@ def __init__(
self.angle_dim += (
self.a_dim // self.a_compress_rate
) * self.a_compress_e_rate
self.a_compress_n_linear = MLPLayer(
self.n_dim,
self.a_dim // self.a_compress_rate,
precision=precision,
bias=False,
seed=child_seed(seed, 8),
)
self.a_compress_e_linear = MLPLayer(
self.e_dim,
self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate,
precision=precision,
bias=False,
seed=child_seed(seed, 9),
self.e_a_compress_dim = (
self.a_dim // (2 * self.a_compress_rate) * self.a_compress_e_rate
)
self.n_a_compress_dim = self.a_dim // self.a_compress_rate
if not self.a_compress_use_split:
self.a_compress_n_linear = MLPLayer(
self.n_dim,
self.n_a_compress_dim,
precision=precision,
bias=False,
seed=child_seed(seed, 8),
)
self.a_compress_e_linear = MLPLayer(
self.e_dim,
self.e_a_compress_dim,
precision=precision,
bias=False,
seed=child_seed(seed, 9),
)
else:
self.a_compress_n_linear = None
self.a_compress_e_linear = None

# edge angle message
self.edge_angle_linear1 = MLPLayer(
Expand Down Expand Up @@ -510,10 +522,19 @@ def forward(
edge_ebd_for_a_before_cp = edge_ebd
# get angle info
if self.a_compress_rate != 0:
assert self.a_compress_n_linear is not None
assert self.a_compress_e_linear is not None
node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_for_a_before_cp)
if not self.a_compress_use_split:
assert self.a_compress_n_linear is not None
assert self.a_compress_e_linear is not None
node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
edge_ebd_for_angle = self.a_compress_e_linear(
edge_ebd_for_a_before_cp
)
else:
# use the first a_compress_dim dim for node and edge
node_ebd_for_angle = node_ebd[:, :, : self.n_a_compress_dim]
edge_ebd_for_angle = edge_ebd_for_a_before_cp[
:, :, :, : self.e_a_compress_dim
]
else:
node_ebd_for_angle = node_ebd
edge_ebd_for_angle = edge_ebd_for_a_before_cp
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
a_compress_e_rate: int = 1,
a_mess_has_n: bool = True,
a_use_e_mess: bool = False,
a_compress_use_split: bool = False,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
update_angle: bool = True,
Expand Down Expand Up @@ -195,6 +196,7 @@ def __init__(
self.a_mess_has_n = a_mess_has_n
self.a_use_e_mess = a_use_e_mess
self.a_compress_e_rate = a_compress_e_rate
self.a_compress_use_split = a_compress_use_split

self.n_dim = n_dim
self.e_dim = e_dim
Expand Down Expand Up @@ -238,6 +240,7 @@ def __init__(
a_compress_rate=self.a_compress_rate,
a_mess_has_n=self.a_mess_has_n,
a_use_e_mess=self.a_use_e_mess,
a_compress_use_split=self.a_compress_use_split,
a_compress_e_rate=self.a_compress_e_rate,
n_multi_edge_message=self.n_multi_edge_message,
axis_neuron=self.axis_neuron,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,6 +1547,12 @@ def dpa3_repflow_args():
optional=True,
default=False,
),
Argument(
"a_compress_use_split",
bool,
optional=True,
default=False,
),
Argument(
"a_compress_e_rate",
int,
Expand Down

0 comments on commit 22c3f87

Please sign in to comment.