From 8fd9565ce03ded903b04ec6c43b1417c48ce6dad Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:49:40 +0800 Subject: [PATCH] feat(pt): support spin virial --- deepmd/pt/loss/ener_spin.py | 16 +++++++ deepmd/pt/model/model/make_model.py | 17 +++++++ deepmd/pt/model/model/spin_model.py | 45 +++++++++++++++---- deepmd/pt/model/model/transform_output.py | 7 +++ source/api_c/include/deepmd.hpp | 12 ++--- source/api_c/src/c_api.cc | 10 ++--- source/api_cc/src/DeepSpinPT.cc | 25 +++++------ source/tests/pt/model/test_autodiff.py | 17 ++++++- source/tests/pt/model/test_ener_spin_model.py | 3 +- .../universal/common/cases/model/utils.py | 12 +++-- source/tests/universal/pt/model/test_model.py | 2 + 11 files changed, 126 insertions(+), 40 deletions(-) diff --git a/deepmd/pt/loss/ener_spin.py b/deepmd/pt/loss/ener_spin.py index 6a926f4051..850f66bf1d 100644 --- a/deepmd/pt/loss/ener_spin.py +++ b/deepmd/pt/loss/ener_spin.py @@ -268,6 +268,22 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): rmse_ae.detach(), find_atom_ener ) + if self.has_v and "virial" in model_pred and "virial" in label: + find_virial = label.get("find_virial", 0.0) + pref_v = pref_v * find_virial + diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9) + l2_virial_loss = torch.mean(torch.square(diff_v)) + if not self.inference: + more_loss["l2_virial_loss"] = self.display_if_exist( + l2_virial_loss.detach(), find_virial + ) + loss += atom_norm * (pref_v * l2_virial_loss) + rmse_v = l2_virial_loss.sqrt() * atom_norm + more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial) + if mae: + mae_v = torch.mean(torch.abs(diff_v)) * atom_norm + more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial) + if not self.inference: more_loss["rmse"] = torch.sqrt(loss.detach()) return model_pred, loss, more_loss diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c32abaa095..2756c66252 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -135,6 +135,7 @@ def forward_common( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + coord_corr_for_virial: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Return model prediction. @@ -153,6 +154,9 @@ def forward_common( atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. + coord_corr_for_virial + The coordinates correction of the atoms for virial. + shape: nf x (nloc x 3) Returns ------- @@ -180,6 +184,14 @@ def forward_common( mixed_types=True, box=bb, ) + if coord_corr_for_virial is not None: + coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype) + extended_coord_corr = torch.gather( + coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3) + ) + else: + extended_coord_corr = None + model_predict_lower = self.forward_common_lower( extended_coord, extended_atype, @@ -188,6 +200,7 @@ def forward_common( do_atomic_virial=do_atomic_virial, fparam=fp, aparam=ap, + extended_coord_corr=extended_coord_corr, ) model_predict = communicate_extended_output( model_predict_lower, @@ -242,6 +255,7 @@ def forward_common_lower( do_atomic_virial: bool = False, comm_dict: Optional[dict[str, torch.Tensor]] = None, extra_nlist_sort: bool = False, + extended_coord_corr: Optional[torch.Tensor] = None, ): """Return model prediction. 
Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -268,6 +282,8 @@ def forward_common_lower( The data needed for communication for parallel inference. extra_nlist_sort whether to forcibly sort the nlist. + extended_coord_corr + coordinates correction for virial in extended region. nf x (nall x 3) Returns ------- @@ -299,6 +315,7 @@ def forward_common_lower( cc_ext, do_atomic_virial=do_atomic_virial, create_graph=self.training, + extended_coord_corr=extended_coord_corr, ) model_predict = self.output_type_cast(model_predict, input_prec) return model_predict diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index ac94668039..a847a869ce 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin): coord = coord.reshape(nframes, nloc, 3) spin = spin.reshape(nframes, nloc, 3) atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) - virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[ - atype - ].reshape([nframes, nloc, 1]) + spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape( + [nframes, nloc, 1] + ) + virtual_coord = coord + spin_dist coord_spin = torch.concat([coord, virtual_coord], dim=-2) - return coord_spin, atype_spin + # for spin virial corr + coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2) + return coord_spin, atype_spin, coord_corr def process_spin_input_lower( self, @@ -78,13 +81,18 @@ def process_spin_input_lower( """ nframes, nall = extended_coord.shape[:2] nloc = nlist.shape[1] - virtual_extended_coord = extended_coord + extended_spin * ( + extended_spin_dist = extended_spin * ( self.virtual_scale_mask.to(extended_atype.device) )[extended_atype].reshape([nframes, nall, 1]) + virtual_extended_coord = extended_coord + extended_spin_dist virtual_extended_atype = extended_atype + self.ntypes_real extended_coord_updated = concat_switch_virtual( extended_coord, virtual_extended_coord, nloc ) + # for spin virial corr + extended_coord_corr = concat_switch_virtual( + torch.zeros_like(extended_coord), -extended_spin_dist, nloc + ) extended_atype_updated = concat_switch_virtual( extended_atype, virtual_extended_atype, nloc ) @@ -100,6 +108,7 @@ def process_spin_input_lower( extended_atype_updated, nlist_updated, mapping_updated, + extended_coord_corr, ) def process_spin_output( @@ -367,7 +376,7 @@ def spin_sampled_func(): sampled = sampled_func() spin_sampled = [] for sys in sampled: - coord_updated, atype_updated = self.process_spin_input( + coord_updated, atype_updated, _ = self.process_spin_input( sys["coord"], sys["atype"], sys["spin"] ) tmp_dict = { @@ -398,7 +407,9 @@ def forward_common( do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: nframes, nloc = atype.shape - coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) + coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input( + coord, atype, spin + ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) model_ret = self.backbone_model.forward_common( @@ -408,6 +419,7 @@ def forward_common( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + coord_corr_for_virial=coord_corr_for_virial, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -454,6 +466,7 @@ def forward_common_lower( extended_atype_updated, nlist_updated, mapping_updated, + 
extended_coord_corr_for_virial, ) = self.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) @@ -469,6 +482,7 @@ def forward_common_lower( do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr_for_virial, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -541,6 +555,11 @@ def translated_output_def(self): output_def["force"].squeeze(-2) output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) output_def["force_mag"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-3) return output_def def forward( @@ -569,7 +588,10 @@ def forward( if self.backbone_model.do_grad_r("energy"): model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) - # not support virial by far + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) return model_predict @torch.jit.export @@ -606,5 +628,10 @@ def forward_lower( model_predict["extended_force_mag"] = model_ret[ "energy_derv_r_mag" ].squeeze(-2) - # not support virial by far + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) return model_predict diff --git a/deepmd/pt/model/model/transform_output.py b/deepmd/pt/model/model/transform_output.py index e15eda6a1d..fcd41e075c 100644 --- a/deepmd/pt/model/model/transform_output.py +++ b/deepmd/pt/model/model/transform_output.py @@ -156,6 +156,7 @@ def fit_output_to_model_output( coord_ext: torch.Tensor, do_atomic_virial: bool = False, create_graph: bool = True, + extended_coord_corr: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. 
@@ -187,6 +188,12 @@ def fit_output_to_model_output( model_ret[kk_derv_r] = dr if vdef.c_differentiable: assert dc is not None + if extended_coord_corr is not None: + dc_corr = ( + dr.squeeze(-2).unsqueeze(-1) + @ extended_coord_corr.unsqueeze(-2) + ).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005 + dc = dc + dc_corr model_ret[kk_derv_c] = dc model_ret[kk_derv_c + "_redu"] = torch.sum( model_ret[kk_derv_c].to(redu_prec), dim=1 diff --git a/source/api_c/include/deepmd.hpp b/source/api_c/include/deepmd.hpp index 8a3656bfc2..a37fe10fa9 100644 --- a/source/api_c/include/deepmd.hpp +++ b/source/api_c/include/deepmd.hpp @@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi { for (int j = 0; j < natoms * 3; j++) { force_mag[i][j] = force_mag_flat[i * natoms * 3 + j]; } - // for (int j = 0; j < 9; j++) { - // virial[i][j] = virial_flat[i * 9 + j]; - // } + for (int j = 0; j < 9; j++) { + virial[i][j] = virial_flat[i * 9 + j]; + } } }; /** @@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi { for (int j = 0; j < natoms * 3; j++) { force_mag[i][j] = force_mag_flat[i * natoms * 3 + j]; } - // for (int j = 0; j < 9; j++) { - // virial[i][j] = virial_flat[i * 9 + j]; - // } + for (int j = 0; j < 9; j++) { + virial[i][j] = virial_flat[i * 9 + j]; + } for (int j = 0; j < natoms; j++) { atom_energy[i][j] = atom_energy_flat[i * natoms + j]; } diff --git a/source/api_c/src/c_api.cc b/source/api_c/src/c_api.cc index 4a0cff1520..3acb28a002 100644 --- a/source/api_c/src/c_api.cc +++ b/source/api_c/src/c_api.cc @@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp, flatten_vector(fm_flat, fm); std::copy(fm_flat.begin(), fm_flat.end(), force_mag); } - // if (virial) { - // std::vector v_flat; - // flatten_vector(v_flat, v); - // std::copy(v_flat.begin(), v_flat.end(), virial); - // } + if (virial) { + std::vector v_flat; + flatten_vector(v_flat, v); + std::copy(v_flat.begin(), v_flat.end(), virial); + } if (atomic_energy) { std::vector ae_flat; flatten_vector(ae_flat, ae); diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 7421b623db..eb43dbf6d0 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -251,8 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("extended_force"); c10::IValue force_mag_ = outputs.at("extended_force_mag"); - // spin model not suported yet - // c10::IValue virial_ = outputs.at("virial"); + c10::IValue virial_ = outputs.at("virial"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); ener.assign(cpu_energy_.data_ptr(), @@ -267,11 +266,11 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, dforce_mag.assign( cpu_force_mag_.data_ptr(), cpu_force_mag_.data_ptr() + cpu_force_mag_.numel()); - // spin model not suported yet - // torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); - // torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); - // virial.assign(cpu_virial_.data_ptr(), - // cpu_virial_.data_ptr() + cpu_virial_.numel()); + + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); // bkw map force.resize(static_cast(nframes) * fwd_map.size() * 3); @@ -415,8 +414,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, 
c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("force"); c10::IValue force_mag_ = outputs.at("force_mag"); - // spin model not suported yet - // c10::IValue virial_ = outputs.at("virial"); + c10::IValue virial_ = outputs.at("virial"); torch::Tensor flat_energy_ = energy_.toTensor().view({-1}); torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU); ener.assign(cpu_energy_.data_ptr(), @@ -431,11 +429,10 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, force_mag.assign( cpu_force_mag_.data_ptr(), cpu_force_mag_.data_ptr() + cpu_force_mag_.numel()); - // spin model not suported yet - // torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); - // torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); - // virial.assign(cpu_virial_.data_ptr(), - // cpu_virial_.data_ptr() + cpu_virial_.numel()); + torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType); + torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU); + virial.assign(cpu_virial_.data_ptr(), + cpu_virial_.data_ptr() + cpu_virial_.numel()); if (atomic) { // c10::IValue atom_virial_ = outputs.at("atom_virial"); c10::IValue atom_energy_ = outputs.at("atom_energy"); diff --git a/source/tests/pt/model/test_autodiff.py b/source/tests/pt/model/test_autodiff.py index 31e06af751..fab637f0f8 100644 --- a/source/tests/pt/model/test_autodiff.py +++ b/source/tests/pt/model/test_autodiff.py @@ -141,11 +141,17 @@ def test( cell = (cell) + 5.0 * torch.eye(3, device="cpu") coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) coord = torch.matmul(coord, cell) + spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) atype = torch.IntTensor([0, 0, 0, 1, 1]) # assumes input to be numpy tensor coord = coord.numpy() + spin = spin.numpy() cell = cell.numpy() - test_keys = ["energy", "force", "virial"] + test_spin = getattr(self, "test_spin", False) + if not test_spin: + test_keys = ["energy", "force", "virial"] + else: + test_keys = ["energy", "force", "force_mag", "virial"] def np_infer( new_cell, @@ -157,6 +163,7 @@ def np_infer( ).unsqueeze(0), torch.tensor(new_cell, device="cpu").unsqueeze(0), atype, + spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0), ) # detach ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys} @@ -251,3 +258,11 @@ def setUp(self) -> None: self.type_split = False self.test_spin = True self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest): + def setUp(self) -> None: + model_params = copy.deepcopy(model_spin) + self.type_split = False + self.test_spin = True + self.model = get_model(model_params).to(env.DEVICE) diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py index ddea392f33..66bb1082a0 100644 --- a/source/tests/pt/model/test_ener_spin_model.py +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -115,7 +115,7 @@ def test_input_output_process(self) -> None: nframes, nloc = self.coord.shape[:2] self.real_ntypes = self.model.spin.get_ntypes_real() # 1. 
test forward input process - coord_updated, atype_updated = self.model.process_spin_input( + coord_updated, atype_updated, _ = self.model.process_spin_input( self.coord, self.atype, self.spin ) # compare atypes of real and virtual atoms @@ -174,6 +174,7 @@ def test_input_output_process(self) -> None: extended_atype_updated, nlist_updated, mapping_updated, + _, ) = self.model.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index 8fe6a131ef..e2a1b4866a 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -892,7 +892,10 @@ def ff_spin(_spin): fdf.reshape(-1, 3), rff.reshape(-1, 3), decimal=places ) - if not test_spin: + # this option can be removed after other backends support spin virial + test_spin_virial = getattr(self, "test_spin_virial", False) + + if not test_spin or test_spin_virial: def ff_cell(bb): input_dict = { @@ -902,6 +905,8 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } + if test_spin: + input_dict["spin"] = spin return module(**input_dict)["energy"] fdv = ( @@ -921,13 +926,12 @@ def ff_cell(bb): "aparam": aparam, "fparam": fparam, } + if test_spin: + input_dict["spin"] = spin rfv = module(**input_dict)["virial"] np.testing.assert_almost_equal( fdv.reshape(-1, 9), rfv.reshape(-1, 9), decimal=places ) - else: - # not support virial by far - pass @unittest.skipIf(TEST_DEVICE == "cpu" and CI, "Skip test on CPU.") def test_device_consistence(self) -> None: diff --git a/source/tests/universal/pt/model/test_model.py b/source/tests/universal/pt/model/test_model.py index 3eb1484c45..ec6cd71782 100644 --- a/source/tests/universal/pt/model/test_model.py +++ b/source/tests/universal/pt/model/test_model.py @@ -713,6 +713,8 @@ def setUpClass(cls) -> None: cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() cls.expected_dim_aparam = ft.get_dim_aparam() + # this option can be removed after other backends support spin virial + cls.test_spin_virial = True @parameterized(
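
A note on the mechanism behind this patch: for every real atom the spin model creates a virtual atom displaced by spin * virtual_scale_mask[atype], and the new code additionally builds a per-atom coordinate correction (zero for real atoms, minus that displacement for virtual ones) that is threaded through forward_common / forward_common_lower. Adding the correction to the (extended) coordinates collapses every virtual atom back onto its real site, which is what allows the virial to be attributed to the real-atom positions. A minimal sketch with toy tensors, mirroring process_spin_input in spin_model.py (the scalar virtual_scale is a stand-in for the per-type virtual_scale_mask):

    import torch

    nf, nloc = 1, 3
    coord = torch.randn(nf, nloc, 3)
    spin = torch.randn(nf, nloc, 3)
    virtual_scale = 0.3  # stand-in for virtual_scale_mask[atype]

    spin_dist = spin * virtual_scale          # displacement of each virtual atom
    virtual_coord = coord + spin_dist         # virtual atoms shifted along the spin
    coord_spin = torch.concat([coord, virtual_coord], dim=-2)
    coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2)
    # adding the correction collapses every virtual atom back onto its real atom
    assert torch.allclose(coord_spin + coord_corr, torch.concat([coord, coord], dim=-2))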
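
The virial itself is corrected in fit_output_to_model_output (transform_output.py): the autodiff atomic virial gets an extra term equal to the per-atom outer product of the coordinate derivative with the coordinate correction, so virtual-atom contributions are referenced to the real positions before the per-frame reduction. A shape-level sketch with random stand-in tensors (sign conventions follow the backbone model and are not restated here):

    import torch

    nf, nall = 2, 6                        # frames, extended atoms (real + virtual)
    dr = torch.randn(nf, nall, 1, 3)       # stand-in for energy_derv_r
    dc = torch.randn(nf, nall, 1, 9)       # stand-in for energy_derv_c from autodiff
    coord_corr = torch.randn(nf, nall, 3)  # stand-in for extended_coord_corr (zeros on real atoms)

    dc_corr = (
        dr.squeeze(-2).unsqueeze(-1) @ coord_corr.unsqueeze(-2)
    ).view(nf, nall, 1, 9)                 # outer product per atom, flattened to 9 components
    dc = dc + dc_corr                      # corrected atomic virial
    virial = dc.sum(dim=1)                 # per-frame reduction, nf x 1 x 9 (the patch also casts to redu_prec)

The reduced tensor is what the spin model now exposes as "virial" / "atom_virial" in translated_output_def and forward, and what DeepSpinPT.cc and the C API read instead of the previously commented-out code paths.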
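
On the training side, EnergySpinLoss (ener_spin.py) now accumulates the same virial term as the plain energy loss. A toy reproduction of the added block, where pref_v, atom_norm, and find_virial play the roles they have in the patch:

    import torch

    nframes, natoms = 2, 5
    atom_norm = 1.0 / natoms
    pref_v, find_virial = 1.0, 1.0

    label_virial = torch.randn(nframes, 9)
    pred_virial = torch.randn(nframes, 1, 9)  # model "virial" output

    pref_v = pref_v * find_virial
    diff_v = label_virial - pred_virial.reshape(-1, 9)
    l2_virial_loss = torch.mean(torch.square(diff_v))
    loss_contrib = atom_norm * (pref_v * l2_virial_loss)  # added onto the total loss
    rmse_v = l2_virial_loss.sqrt() * atom_norm            # logged as "rmse_v"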