Skip to content

Commit

Permalink
rename and add uts
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 18, 2024
1 parent e23dc5f commit 527cb85
Show file tree
Hide file tree
Showing 9 changed files with 603 additions and 581 deletions.
3 changes: 0 additions & 3 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(
a_rcut_smth: float = 3.5,
a_sel: int = 20,
axis_neuron: int = 4,
node_has_conv: bool = False,
update_angle: bool = True,
update_style: str = "res_residual",
update_residual: float = 0.1,
Expand Down Expand Up @@ -73,7 +72,6 @@ def __init__(
self.a_rcut_smth = a_rcut_smth
self.a_sel = a_sel
self.axis_neuron = axis_neuron
self.node_has_conv = node_has_conv # tmp
self.update_angle = update_angle
self.update_style = update_style
self.update_residual = update_residual
Expand All @@ -98,7 +96,6 @@ def serialize(self) -> dict:
"a_rcut_smth": self.a_rcut_smth,
"a_sel": self.a_sel,
"axis_neuron": self.axis_neuron,
"node_has_conv": self.node_has_conv, # tmp
"update_angle": self.update_angle,
"update_style": self.update_style,
"update_residual": self.update_residual,
Expand Down
25 changes: 14 additions & 11 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def init_subclass_params(sub_data, sub_class):
e_dim=self.repflow_args.e_dim,
a_dim=self.repflow_args.a_dim,
axis_neuron=self.repflow_args.axis_neuron,
node_has_conv=self.repflow_args.node_has_conv,
update_angle=self.repflow_args.update_angle,
activation_function=self.activation_function,
update_style=self.repflow_args.update_style,
Expand Down Expand Up @@ -299,7 +298,7 @@ def change_type_map(
extend_descrpt_stat(
repflow,
type_map,
des_with_stat=model_with_new_type_stat.repflow
des_with_stat=model_with_new_type_stat.repflows
if model_with_new_type_stat is not None
else None,
)
Expand Down Expand Up @@ -380,6 +379,7 @@ def serialize(self) -> dict:
}
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
"angle_embd": repflows.angle_embd.serialize(),
"repflow_layers": [layer.serialize() for layer in repflows.layers],
"env_mat": DPEnvMat(repflows.rcut, repflows.rcut_smth).serialize(),
"@variables": {
Expand Down Expand Up @@ -417,6 +417,9 @@ def t_cvt(xx):
env_mat = repflow_variable.pop("env_mat")
repflow_layers = repflow_variable.pop("repflow_layers")
obj.repflows.edge_embd = MLPLayer.deserialize(repflow_variable.pop("edge_embd"))
obj.repflows.angle_embd = MLPLayer.deserialize(
repflow_variable.pop("angle_embd")
)
obj.repflows["davg"] = t_cvt(statistic_repflows["davg"])
obj.repflows["dstd"] = t_cvt(statistic_repflows["dstd"])
obj.repflows.layers = torch.nn.ModuleList(
Expand Down Expand Up @@ -449,12 +452,12 @@ def forward(
Returns
-------
node_embd
node_ebd
The output descriptor. shape: nf x nloc x n_dim (or n_dim + tebd_dim)
rot_mat
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x e_dim x 3
edge_embd
edge_ebd
The edge embedding.
shape: nf x nloc x nnei x e_dim
h2
Expand All @@ -469,23 +472,23 @@ def forward(
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3

node_embd_ext = self.type_embedding(extended_atype)
node_embd_inp = node_embd_ext[:, :nloc, :]
node_ebd_ext = self.type_embedding(extended_atype)
node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_embd, edge_embd, h2, rot_mat, sw = self.repflows(
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
nlist,
extended_coord,
extended_atype,
node_embd_ext,
node_ebd_ext,
mapping,
comm_dict=comm_dict,
)
if self.concat_output_tebd:
node_embd = torch.cat([node_embd, node_embd_inp], dim=-1)
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)
return (
node_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
edge_embd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
)
Expand Down
Loading

0 comments on commit 527cb85

Please sign in to comment.