Skip to content

Commit

Permalink
support PBC
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Aug 22, 2024
1 parent 579aafe commit 4d27bb5
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 21 deletions.
77 changes: 56 additions & 21 deletions deepmd/pt/model/model/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,17 +435,38 @@ def forward_lower(
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3)
return model_predict

def forward_lower_common(
def forward_lower_common( # noqa: PLR0915
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None, # noqa: ARG002
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
do_atomic_virial: bool = False, # noqa: ARG002
comm_dict: Optional[dict[str, torch.Tensor]] = None,
) -> dict[str, torch.Tensor]:
"""Forward lower common pass of the model.
Parameters
----------
extended_coord : torch.Tensor
The extended coordinates of atoms.
extended_atype : torch.Tensor
The extended atomic types of atoms.
nlist : torch.Tensor
The neighbor list.
mapping : torch.Tensor, optional
The mapping tensor.
fparam : torch.Tensor, optional
The frame parameters.
aparam : torch.Tensor, optional
The atomic parameters.
do_atomic_virial : bool, optional
Whether to compute atomic virial.
comm_dict : dict[str, torch.Tensor], optional
The communication dictionary.
"""
extended_coord_ = extended_coord
nf, nall, _ = extended_coord_.shape
_, nloc, _ = nlist.shape
Expand Down Expand Up @@ -515,21 +536,35 @@ def forward_lower_common(
device=extended_coord_ff.device,
),
},
compute_virials=True,
compute_force=False,
compute_virials=False,
compute_stress=False,
compute_displacement=True,
training=self.training,
)
energy = ret["energy"]
assert energy is not None
energy = energy.view(1, 1).to(extended_coord_.dtype)
force = ret["forces"]
assert force is not None
force = force.view(1, nall, 3).to(extended_coord_.dtype)
virial = ret["virials"]
assert virial is not None
virial = virial.view(1, 9)

atom_energy = ret["node_energy"]
assert atom_energy is not None
atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nall]
if atom_energy is None:
msg = "atom_energy is None"
raise ValueError(msg)
atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nloc]
energy = torch.sum(atom_energy, dim=1).view(1, 1).to(extended_coord_.dtype)
grad_outputs: list[Optional[torch.Tensor]] = [
torch.ones_like(energy),
]
displacement = ret["displacement"]
force, virial = torch.autograd.grad(
outputs=[energy],
inputs=[extended_coord_ff, displacement],
grad_outputs=grad_outputs,
retain_graph=True,
create_graph=self.training,
)
force = -force
virial = -virial
force = force.view(1, nall, 3).to(extended_coord_.dtype)
virial = virial.view(1, 9).to(extended_coord_.dtype)

energies.append(energy)
forces.append(force)
virials.append(virial)
Expand All @@ -544,7 +579,7 @@ def forward_lower_common(
"energy_derv_r": forces.view(nf, nall, 1, 3),
"energy_derv_c_redu": virials.view(nf, 1, 9),
# take the first nloc atoms to match other models
"energy": atom_energies.view(nf, nall, 1)[:, :nloc, :],
"energy": atom_energies.view(nf, nloc, 1),
# fake atom_virial
"energy_derv_c": torch.zeros(
(nf, nall, 1, 9),
Expand Down
60 changes: 60 additions & 0 deletions examples/water/mace/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"_comment1": " model parameters",
"model": {
"type": "mace",
"type_map": [
"O",
"H"
],
"r_max": 6.0,
"sel": "auto",
"_comment2": " that's all"
},

"learning_rate": {
"type": "exp",
"decay_steps": 5000,
"start_lr": 0.001,
"stop_lr": 3.51e-8,
"_comment5": "that's all"
},

"loss": {
"type": "ener",
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0,
"_comment6": " that's all"
},

"training": {
"training_data": {
"systems": [
"../data/data_0/",
"../data/data_1/",
"../data/data_2/"
],
"batch_size": "auto",
"_comment7": "that's all"
},
"validation_data": {
"systems": [
"../data/data_3"
],
"batch_size": 1,
"numb_btch": 3,
"_comment8": "that's all"
},
"numb_steps": 1000000,
"seed": 10,
"disp_file": "lcurve.out",
"disp_freq": 100,
"save_freq": 1000,
"_comment9": "that's all"
},

"_comment10": "that's all"
}

0 comments on commit 4d27bb5

Please sign in to comment.