Skip to content

Commit

Permalink
add pipeline mess
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 30, 2024
1 parent 0d60dbb commit ef48cca
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 2 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
a_sel: int = 20,
a_compress_rate: int = 0,
a_mess_has_n: bool = True,
a_use_e_mess: bool = False,
a_compress_e_rate: int = 1,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self.update_residual_init = update_residual_init
self.skip_stat = skip_stat
self.a_mess_has_n = a_mess_has_n
self.a_use_e_mess = a_use_e_mess
self.a_compress_e_rate = a_compress_e_rate

def __getitem__(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def init_subclass_params(sub_data, sub_class):
a_compress_rate=self.repflow_args.a_compress_rate,
a_compress_e_rate=self.repflow_args.a_compress_e_rate,
a_mess_has_n=self.repflow_args.a_mess_has_n,
a_use_e_mess=self.repflow_args.a_use_e_mess,
n_multi_edge_message=self.repflow_args.n_multi_edge_message,
axis_neuron=self.repflow_args.axis_neuron,
update_angle=self.repflow_args.update_angle,
Expand Down
10 changes: 8 additions & 2 deletions deepmd/pt/model/descriptor/repflow_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
a_dim: int = 64,
a_compress_rate: int = 0,
a_mess_has_n: bool = True,
a_use_e_mess: bool = False,
a_compress_e_rate: int = 1,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
self.e_dim = e_dim
self.a_dim = a_dim
self.a_compress_rate = a_compress_rate
self.a_use_e_mess = a_use_e_mess
if a_compress_rate != 0:
assert a_dim % (2 * a_compress_rate) == 0, (
f"For a_compress_rate of {a_compress_rate}, a_dim must be divisible by {2 * a_compress_rate}. "
Expand Down Expand Up @@ -502,15 +504,19 @@ def forward(
assert self.angle_self_linear is not None
assert self.edge_angle_linear1 is not None
assert self.edge_angle_linear2 is not None
if self.a_use_e_mess:
edge_ebd_for_a_before_cp = edge_self_update
else:
edge_ebd_for_a_before_cp = edge_ebd
# get angle info
if self.a_compress_rate != 0:
assert self.a_compress_n_linear is not None
assert self.a_compress_e_linear is not None
node_ebd_for_angle = self.a_compress_n_linear(node_ebd)
edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd)
edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_for_a_before_cp)
else:
node_ebd_for_angle = node_ebd
edge_ebd_for_angle = edge_ebd
edge_ebd_for_angle = edge_ebd_for_a_before_cp

# nb x nloc x a_nnei x a_nnei x n_dim
node_for_angle_info = torch.tile(
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
a_compress_rate: int = 0,
a_compress_e_rate: int = 1,
a_mess_has_n: bool = True,
a_use_e_mess: bool = False,
n_multi_edge_message: int = 1,
axis_neuron: int = 4,
update_angle: bool = True,
Expand Down Expand Up @@ -192,6 +193,7 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.skip_stat = skip_stat
self.a_mess_has_n = a_mess_has_n
self.a_use_e_mess = a_use_e_mess
self.a_compress_e_rate = a_compress_e_rate

self.n_dim = n_dim
Expand Down Expand Up @@ -235,6 +237,7 @@ def __init__(
a_dim=self.a_dim,
a_compress_rate=self.a_compress_rate,
a_mess_has_n=self.a_mess_has_n,
a_use_e_mess=self.a_use_e_mess,
a_compress_e_rate=self.a_compress_e_rate,
n_multi_edge_message=self.n_multi_edge_message,
axis_neuron=self.axis_neuron,
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,6 +1541,12 @@ def dpa3_repflow_args():
optional=True,
default=True,
),
Argument(
"a_use_e_mess",
bool,
optional=True,
default=False,
),
Argument(
"a_compress_e_rate",
int,
Expand Down

0 comments on commit ef48cca

Please sign in to comment.