From 4d27bb5a99b1f88bb1cd6b7e2de0247b500c358c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 22 Aug 2024 15:33:11 -0400 Subject: [PATCH] support PBC Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/model/mace.py | 77 ++++++++++++++++++++++++---------- examples/water/mace/input.json | 60 ++++++++++++++++++++++++++ 2 files changed, 116 insertions(+), 21 deletions(-) create mode 100644 examples/water/mace/input.json diff --git a/deepmd/pt/model/model/mace.py b/deepmd/pt/model/model/mace.py index 3276507cb8..6f42014ca5 100644 --- a/deepmd/pt/model/model/mace.py +++ b/deepmd/pt/model/model/mace.py @@ -435,17 +435,38 @@ def forward_lower( model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(-3) return model_predict - def forward_lower_common( + def forward_lower_common( # noqa: PLR0915 self, - extended_coord, - extended_atype, - nlist, - mapping: Optional[torch.Tensor] = None, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, # noqa: ARG002 fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - do_atomic_virial: bool = False, - comm_dict: Optional[Dict[str, torch.Tensor]] = None, - ): + do_atomic_virial: bool = False, # noqa: ARG002 + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + """Forward lower common pass of the model. + + Parameters + ---------- + extended_coord : torch.Tensor + The extended coordinates of atoms. + extended_atype : torch.Tensor + The extended atomic types of atoms. + nlist : torch.Tensor + The neighbor list. + mapping : torch.Tensor, optional + The mapping tensor. + fparam : torch.Tensor, optional + The frame parameters. + aparam : torch.Tensor, optional + The atomic parameters. + do_atomic_virial : bool, optional + Whether to compute atomic virial. + comm_dict : dict[str, torch.Tensor], optional + The communication dictionary. + """ extended_coord_ = extended_coord nf, nall, _ = extended_coord_.shape _, nloc, _ = nlist.shape @@ -515,21 +536,35 @@ def forward_lower_common( device=extended_coord_ff.device, ), }, - compute_virials=True, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=True, training=self.training, ) - energy = ret["energy"] - assert energy is not None - energy = energy.view(1, 1).to(extended_coord_.dtype) - force = ret["forces"] - assert force is not None - force = force.view(1, nall, 3).to(extended_coord_.dtype) - virial = ret["virials"] - assert virial is not None - virial = virial.view(1, 9) + atom_energy = ret["node_energy"] - assert atom_energy is not None - atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nall] + if atom_energy is None: + msg = "atom_energy is None" + raise ValueError(msg) + atom_energy = atom_energy.view(1, nall).to(extended_coord_.dtype)[:, :nloc] + energy = torch.sum(atom_energy, dim=1).view(1, 1).to(extended_coord_.dtype) + grad_outputs: list[Optional[torch.Tensor]] = [ + torch.ones_like(energy), + ] + displacement = ret["displacement"] + force, virial = torch.autograd.grad( + outputs=[energy], + inputs=[extended_coord_ff, displacement], + grad_outputs=grad_outputs, + retain_graph=True, + create_graph=self.training, + ) + force = -force + virial = -virial + force = force.view(1, nall, 3).to(extended_coord_.dtype) + virial = virial.view(1, 9).to(extended_coord_.dtype) + energies.append(energy) forces.append(force) virials.append(virial) @@ -544,7 +579,7 @@ def forward_lower_common( "energy_derv_r": forces.view(nf, nall, 1, 3), "energy_derv_c_redu": virials.view(nf, 1, 9), # take the first nloc atoms to match other models - "energy": atom_energies.view(nf, nall, 1)[:, :nloc, :], + "energy": atom_energies.view(nf, nloc, 1), # fake atom_virial "energy_derv_c": torch.zeros( (nf, nall, 1, 9), diff --git a/examples/water/mace/input.json b/examples/water/mace/input.json new file mode 100644 index 0000000000..c0f03ea82c --- /dev/null +++ b/examples/water/mace/input.json @@ -0,0 +1,60 @@ +{ + "_comment1": " model parameters", + "model": { + "type": "mace", + "type_map": [ + "O", + "H" + ], + "r_max": 6.0, + "sel": "auto", + "_comment2": " that's all" + }, + + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + "_comment5": "that's all" + }, + + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + "_comment6": " that's all" + }, + + "training": { + "training_data": { + "systems": [ + "../data/data_0/", + "../data/data_1/", + "../data/data_2/" + ], + "batch_size": "auto", + "_comment7": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 3, + "_comment8": "that's all" + }, + "numb_steps": 1000000, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000, + "_comment9": "that's all" + }, + + "_comment10": "that's all" +}