diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index b9c4971116..450b878083 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -46,6 +46,7 @@ ) from deepmd.pt.utils.env import ( DEVICE, + load_op, ) from deepmd.pt.utils.finetune import ( change_finetune_model_params, @@ -299,7 +300,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None): set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None) log.debug("Log handles were successfully set") log.info("DeepMD version: %s", __version__) - + load_op() if FLAGS.command == "train": train(FLAGS) elif FLAGS.command == "freeze": diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index c921538203..55a5797dab 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -105,6 +105,7 @@ def forward_common_atomic( mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Common interface for atomic inference. @@ -153,6 +154,7 @@ def forward_common_atomic( mapping=mapping, fparam=fparam, aparam=aparam, + comm_dict=comm_dict, ) # nf x nloc diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 6aa8df7aee..1f6eb146cf 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -134,6 +134,7 @@ def forward_atomic( mapping: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ) -> Dict[str, torch.Tensor]: """Return atomic prediction. @@ -163,10 +164,7 @@ def forward_atomic( if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) descriptor, rot_mat, g2, h2, sw = self.descriptor( - extended_coord, - extended_atype, - nlist, - mapping=mapping, + extended_coord, extended_atype, nlist, mapping=mapping, comm_dict=comm_dict ) assert descriptor is not None # energy, force diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index fb792a51e2..443f07fcd6 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Callable, + Dict, List, Optional, Tuple, @@ -395,6 +396,7 @@ def forward( extended_atype: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ): """Compute the descriptor. @@ -450,11 +452,12 @@ def forward( # linear to change shape g1 = self.g1_shape_tranform(g1) # mapping g1 - assert mapping is not None - mapping_ext = ( - mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1]) - ) - g1_ext = torch.gather(g1, 1, mapping_ext) + if comm_dict is None: + assert mapping is not None + # mapping_ext = ( + # mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1]) + # ) + # g1_ext = torch.gather(g1, 1, mapping_ext) # repformer g1, g2, h2, rot_mat, sw = self.repformers( nlist_dict[ @@ -464,8 +467,9 @@ def forward( ], extended_coord, extended_atype, - g1_ext, + g1, mapping, + comm_dict, ) if self.concat_output_tebd: g1 = torch.cat([g1, g1_inp], dim=-1) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index a908d2e057..adaea32b8a 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -227,6 +227,7 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + @torch.jit.script_method def forward( self, nlist: torch.Tensor, @@ -234,9 +235,10 @@ def forward( extended_atype: torch.Tensor, extended_atype_embd: Optional[torch.Tensor] = None, mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ): - assert mapping is not None - assert extended_atype_embd is not None + if comm_dict is None: + assert extended_atype_embd is not None nframes, nloc, nnei = nlist.shape nall = extended_coord.view(nframes, -1).shape[1] // 3 atype = extended_atype[:, :nloc] @@ -257,8 +259,12 @@ def forward( sw = sw.masked_fill(~nlist_mask, 0.0) # [nframes, nloc, tebd_dim] - atype_embd = extended_atype_embd[:, :nloc, :] - assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] + # atype_embd = extended_atype_embd[:, :nloc, :] + atype_embd = extended_atype_embd + if atype_embd is not None: + assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim] + else: + raise NotImplementedError g1 = self.act(atype_embd) # nb x nloc x nnei x 1, nb x nloc x nnei x 3 @@ -275,11 +281,43 @@ def forward( # if the a neighbor is real or not is indicated by nlist_mask nlist[nlist == -1] = 0 # nb x nall x ng1 - mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) + if comm_dict is None: + assert mapping is not None + mapping = ( + mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim) + ) for idx, ll in enumerate(self.layers): # g1: nb x nloc x ng1 # g1_ext: nb x nall x ng1 - g1_ext = torch.gather(g1, 1, mapping) + if comm_dict is None: + assert mapping is not None + g1_ext = torch.gather(g1, 1, mapping) + else: + # padding = torch.zeros(nall-nloc, g1.size(2),device=mydev) + # g1 = torch.cat((g1.squeeze(0), padding), dim=0) + n_padding = nall - nloc + g1 = torch.nn.functional.pad( + g1.squeeze(0), (0, 0, 0, n_padding), value=0.0 + ) + assert "send_list" in comm_dict + assert "send_proc" in comm_dict + assert "recv_proc" in comm_dict + assert "send_num" in comm_dict + assert "recv_num" in comm_dict + assert "communicator" in comm_dict + ret = env.op_module.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + g1, + comm_dict["communicator"], + torch.tensor(nloc), + torch.tensor(nall - nloc), + ) + g1_ext = ret[0].unsqueeze(0) + g1, g2, h2 = ll.forward( g1_ext, g2, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index e17b7c5d54..d77f7c336c 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -191,6 +191,7 @@ def forward( atype_ext: torch.Tensor, nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ): """Compute the descriptor. diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 5217293623..eba738271b 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -6,6 +6,7 @@ import torch + from .dp_model import ( DPModel, ) @@ -69,6 +70,7 @@ def forward_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ): model_ret = self.forward_common_lower( extended_coord, @@ -78,6 +80,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 167ad81923..bc5847f00a 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -181,6 +181,7 @@ def forward_common_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + comm_dict: Optional[Dict[str, torch.Tensor]] = None, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -224,6 +225,7 @@ def forward_common_lower( mapping=mapping, fparam=fp, aparam=ap, + comm_dict=comm_dict, ) model_predict = fit_output_to_model_output( atomic_ret, diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 0b92953255..ab9d546349 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import os +from typing import ( + Any, +) import numpy as np import torch @@ -80,3 +83,12 @@ "ENERGY_BIAS_TRAINABLE", "LOCAL_RANK", ] + + +def load_op(): + torch.ops.load_library( + "/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/source/op_pt/libop_pt.so" + ) + + +op_module: Any = torch.ops.my_ops diff --git a/source/api_c/include/c_api.h b/source/api_c/include/c_api.h index 911813e428..254890dda8 100644 --- a/source/api_c/include/c_api.h +++ b/source/api_c/include/c_api.h @@ -25,6 +25,19 @@ extern DP_Nlist* DP_NewNlist(int inum_, int* numneigh_, int** firstneigh_); +extern DP_Nlist* DP_NewNlist_comm(int inum_, + int* ilist_, + int* numneigh_, + int** firstneigh_, + int nswap, + int* sendnum, + int* recvnum, + int* firstrecv, + int** sendlist, + int* sendproc, + int* recvproc, + int world); + /** * @brief Delete a neighbor list. * diff --git a/source/api_c/include/deepmd.hpp b/source/api_c/include/deepmd.hpp index 16b8f08cad..8d284bf9b8 100644 --- a/source/api_c/include/deepmd.hpp +++ b/source/api_c/include/deepmd.hpp @@ -502,6 +502,20 @@ inline double *_DP_Get_Energy_Pointer(double &vec, const int nframes) { namespace deepmd { namespace hpp { +// struct CommData { +// int nswap; +// int* sendnum; +// int* recvnum; +// int* firstrecv; +// int** sendlist; +// int* sendproc; +// int* recvproc; +// long int* world; + +// CommData() : nswap(0), sendnum(nullptr), recvnum(nullptr), +// firstrecv(nullptr), sendlist(nullptr), +// sendproc(nullptr), recvproc(nullptr),world(nullptr) {} +// }; /** * @brief Neighbor list. **/ @@ -522,6 +536,36 @@ struct InputNlist { nl(DP_NewNlist(inum_, ilist_, numneigh_, firstneigh_)) { DP_CHECK_OK(DP_NlistCheckOK, nl); }; + InputNlist(int inum_, + int *ilist_, + int *numneigh_, + int **firstneigh_, + int nswap, + int *sendnum, + int *recvnum, + int *firstrecv, + int **sendlist, + int *sendproc, + int *recvproc, + int world) + : inum(inum_), + ilist(ilist_), + numneigh(numneigh_), + firstneigh(firstneigh_), + nl(DP_NewNlist_comm(inum_, + ilist_, + numneigh_, + firstneigh_, + nswap, + sendnum, + recvnum, + firstrecv, + sendlist, + sendproc, + recvproc, + world)){ + // DP_CHECK_OK(DP_NlistCheckOK, nl); + }; ~InputNlist() { DP_DeleteNlist(nl); }; /// @brief C API neighbor list. DP_Nlist *nl; @@ -798,7 +842,6 @@ class DeepPot { aparam); const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr; const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr; - _DP_DeepPotComputeNList( dp, nframes, natoms, coord_, atype_, box_, nghost, lmp_list.nl, ago, fparam__, aparam__, ener_, force_, virial_, nullptr, nullptr); @@ -2117,6 +2160,5 @@ void select_map(std::vector &out, out.resize(static_cast(nall2) * stride); DP_SelectMapInt(&in[0], &fwd_map[0], stride, nall1, nall2, &out[0]); }; - } // namespace hpp } // namespace deepmd diff --git a/source/api_c/src/c_api.cc b/source/api_c/src/c_api.cc index 79dc486e0d..0d5c3694fc 100644 --- a/source/api_c/src/c_api.cc +++ b/source/api_c/src/c_api.cc @@ -24,6 +24,24 @@ DP_Nlist* DP_NewNlist(int inum_, deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_); DP_Nlist* new_nl = new DP_Nlist(nl); return new_nl;) } +DP_Nlist* DP_NewNlist_comm(int inum_, + int* ilist_, + int* numneigh_, + int** firstneigh_, + int nswap, + int* sendnum, + int* recvnum, + int* firstrecv, + int** sendlist, + int* sendproc, + int* recvproc, + int world) { + deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, nswap, sendnum, + recvnum, firstrecv, sendlist, sendproc, recvproc, + world); + DP_Nlist* new_nl = new DP_Nlist(nl); + return new_nl; +} void DP_DeleteNlist(DP_Nlist* nl) { delete nl; } @@ -268,7 +286,6 @@ inline void DP_DeepPotComputeNList_variant(DP_DeepPot* dp, } std::vector e; std::vector f, v, ae, av; - DP_REQUIRES_OK(dp, dp->dp.compute(e, f, v, ae, av, coord_, atype_, cell_, nghost, nlist->nl, ago, fparam_, aparam_)); // copy from C++ vectors to C arrays, if not NULL pointer diff --git a/source/api_cc/include/DeepPotPT.h b/source/api_cc/include/DeepPotPT.h index a7fc910b46..0b61d1a2a2 100644 --- a/source/api_cc/include/DeepPotPT.h +++ b/source/api_cc/include/DeepPotPT.h @@ -327,6 +327,7 @@ class DeepPotPT : public DeepPotBase { int gpu_id; bool gpu_enabled; at::Tensor firstneigh_tensor; + torch::Dict comm_dict; }; } // namespace deepmd diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 2010780a6c..2901e9aafc 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -13,6 +13,7 @@ namespace deepmd { typedef double ENERGYTYPE; +// TODO: currently we only implement TF&PT; reserve for future use enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown }; struct NeighborListData { diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index 2c3fd1d865..70f1daaff1 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -105,8 +105,10 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, options = torch::TensorOptions().dtype(torch::kFloat32); floatType = torch::kFloat32; } - auto int_options = torch::TensorOptions().dtype(torch::kInt64); - auto int32_options = torch::TensorOptions().dtype(torch::kInt32); + auto int32_option = + torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt32); + auto int_option = + torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64); // select real atoms std::vector dcoord, dforce, aparam_, datom_energy, datom_virial; @@ -122,11 +124,39 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, .to(device); std::vector atype_64(datype.begin(), datype.end()); at::Tensor atype_Tensor = - torch::from_blob(atype_64.data(), {1, nall_real}, int_options).to(device); + torch::from_blob(atype_64.data(), {1, nall_real}, int_option).to(device); if (ago == 0) { nlist_data.copy_from_nlist(lmp_list); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); + + int nswap = lmp_list.nswap; + torch::Tensor sendproc_tensor = + torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); + torch::Tensor recvproc_tensor = + torch::from_blob(lmp_list.recvproc, {nswap}, int32_option); + torch::Tensor firstrecv_tensor = + torch::from_blob(lmp_list.firstrecv, {nswap}, int32_option); + torch::Tensor recvnum_tensor = + torch::from_blob(lmp_list.recvnum, {nswap}, int32_option); + torch::Tensor sendnum_tensor = + torch::from_blob(lmp_list.sendnum, {nswap}, int32_option); + // torch::Tensor communicator_tensor = + // torch::from_blob(lmp_list.commdata->world, {1}, int_option); + torch::Tensor communicator_tensor = + torch::tensor(lmp_list.world, int32_option); + torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option); + int total_send = + std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0); + torch::Tensor sendlist_tensor = + torch::from_blob(lmp_list.sendlist, {total_send}, int32_option); + + comm_dict.insert("send_list", sendlist_tensor); + comm_dict.insert("send_proc", sendproc_tensor); + comm_dict.insert("recv_proc", recvproc_tensor); + comm_dict.insert("send_num", sendnum_tensor); + comm_dict.insert("recv_num", recvnum_tensor); + comm_dict.insert("communicator", communicator_tensor); } at::Tensor firstneigh = createNlistTensor(nlist_data.jlist); firstneigh_tensor = firstneigh.to(torch::kInt64).to(device); @@ -152,7 +182,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener, module .run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor, firstneigh_tensor, optional_tensor, fparam_tensor, - aparam_tensor, do_atom_virial_tensor) + aparam_tensor, do_atom_virial_tensor, comm_dict) .toGenericDict(); c10::IValue energy_ = outputs.at("energy"); c10::IValue force_ = outputs.at("extended_force"); diff --git a/source/lib/include/neighbor_list.h b/source/lib/include/neighbor_list.h index eb510eb25b..67e5ea80b1 100644 --- a/source/lib/include/neighbor_list.h +++ b/source/lib/include/neighbor_list.h @@ -11,6 +11,42 @@ #include "utilities.h" namespace deepmd { +// struct CommData { +// int nswap; +// int* sendnum; +// int* recvnum; +// int* firstrecv; +// int** sendlist; +// int* sendproc; +// int* recvproc; +// long int world; + +// CommData() +// : nswap(0), +// sendnum(nullptr), +// recvnum(nullptr), +// firstrecv(nullptr), +// sendlist(nullptr), +// sendproc(nullptr), +// recvproc(nullptr), +// world(0){}; +// CommData(int nswap, +// int* sendnum, +// int* recvnum, +// int* firstrecv, +// int** sendlist, +// int* sendproc, +// int* recvproc, +// long int world) +// : nswap(nswap), +// sendnum(sendnum), +// recvnum(recvnum), +// firstrecv(firstrecv), +// sendlist(sendlist), +// sendproc(sendproc), +// recvproc(recvproc), +// world(world) {} +// }; /** * @brief Construct InputNlist with the input LAMMPS nbor list info. @@ -26,12 +62,65 @@ struct InputNlist { int* numneigh; /// Array stores the core region atom's neighbor index int** firstneigh; - InputNlist() : inum(0), ilist(NULL), numneigh(NULL), firstneigh(NULL){}; + + int nswap; + int* sendnum; + int* recvnum; + int* firstrecv; + int** sendlist; + int* sendproc; + int* recvproc; + int world; + InputNlist() + : inum(0), + ilist(NULL), + numneigh(NULL), + firstneigh(NULL), + nswap(0), + sendnum(nullptr), + recvnum(nullptr), + firstrecv(nullptr), + sendlist(nullptr), + sendproc(nullptr), + recvproc(nullptr), + world(0){}; InputNlist(int inum_, int* ilist_, int* numneigh_, int** firstneigh_) : inum(inum_), ilist(ilist_), numneigh(numneigh_), - firstneigh(firstneigh_){}; + firstneigh(firstneigh_), + nswap(0), + sendnum(nullptr), + recvnum(nullptr), + firstrecv(nullptr), + sendlist(nullptr), + sendproc(nullptr), + recvproc(nullptr), + world(0){}; + InputNlist(int inum_, + int* ilist_, + int* numneigh_, + int** firstneigh_, + int nswap, + int* sendnum, + int* recvnum, + int* firstrecv, + int** sendlist, + int* sendproc, + int* recvproc, + int world) + : inum(inum_), + ilist(ilist_), + numneigh(numneigh_), + firstneigh(firstneigh_), + nswap(nswap), + sendnum(sendnum), + recvnum(recvnum), + firstrecv(firstrecv), + sendlist(sendlist), + sendproc(sendproc), + recvproc(recvproc), + world(world){}; ~InputNlist(){}; }; diff --git a/source/lmp/builtin.cmake b/source/lmp/builtin.cmake index f29e9d3319..3828a57f4c 100644 --- a/source/lmp/builtin.cmake +++ b/source/lmp/builtin.cmake @@ -24,6 +24,12 @@ target_sources( ${LAMMPS_SOURCE_DIR}/EXTRA-FIX/fix_ttm.cpp # for ttm ) target_link_libraries(lammps PUBLIC DeePMD::deepmd_c) +target_link_libraries( + lammps + PUBLIC + -Wl,--no-as-needed + "/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/source/op_pt/libop_pt.so" + "/home/zhangxiangyu/.conda/envs/dp-cxxabi/lib/libmpi.so") target_include_directories( lammps PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_LIST_DIR} ${LAMMPS_SOURCE_DIR}/KSPACE ${LAMMPS_SOURCE_DIR}/EXTRA-FIX) diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 90aa453143..7aa1bdbe41 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: LGPL-3.0-or-later #include +#include #include #include #include @@ -471,6 +472,19 @@ void PairDeepMD::compute(int eflag, int vflag) { int nall = nlocal + nghost; int newton_pair = force->newton_pair; + // for dpa2 communication + // deepmd_compat::CommData* commdata = new deepmd_compat::CommData(); + // commdata->nswap = cb->nswap; + // commdata->sendnum = cb->sendnum; // dim: nswap + // commdata->recvnum = cb->recvnum; // dim: nswap + // commdata->firstrecv = cb->firstrecv; // dim: nswap + // commdata->sendlist = cb->sendlist; // dim: nswap x sendnum[nswap] + // commdata->sendproc = cb->sendproc; // dim: nswap + // commdata->recvproc = cb->recvproc; // dim: nswap + assert(sizeof(MPI_Comm) == sizeof(int)); + // std::cout<<"world:"<prd; vector dspin(nall * 3, 0.); vector dfm(nall * 3, 0.); double **sp = atom->sp; @@ -550,8 +564,15 @@ void PairDeepMD::compute(int eflag, int vflag) { multi_models_mod_devi = (numb_models > 1 && (out_freq > 0 && update->ntimestep % out_freq == 0)); if (do_ghost) { - deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh, - list->firstneigh); + deepmd_compat::InputNlist lmp_list( + list->inum, list->ilist, list->numneigh, list->firstneigh, + commdata_->nswap, commdata_->sendnum, commdata_->recvnum, + commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc, + commdata_->recvproc, world_int); + // else + // deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, + // list->numneigh, + // list->firstneigh); deepmd_compat::InputNlist extend_lmp_list; if (atom->sp_flag) { extend(extend_inum, extend_ilist, extend_numneigh, extend_neigh, @@ -1275,11 +1296,16 @@ void PairDeepMD::coeff(int narg, char **arg) { } } } + + // dpa2 communication + commdata_ = (CommBrickDeepMD *)comm; } void PairDeepMD::init_style() { #if LAMMPS_VERSION_NUMBER >= 20220324 neighbor->add_request(this, NeighConst::REQ_FULL); + atom->map_user = 2; + atom->map_init(1); #else int irequest = neighbor->request(this, instance_me); neighbor->requests[irequest]->half = 0; diff --git a/source/lmp/pair_deepmd.h b/source/lmp/pair_deepmd.h index cd72dc7b2a..bb4e48cfa1 100644 --- a/source/lmp/pair_deepmd.h +++ b/source/lmp/pair_deepmd.h @@ -32,10 +32,14 @@ namespace deepmd_compat = deepmd::hpp; #include #include +#include "comm_brick.h" + #define FLOAT_PREC double namespace LAMMPS_NS { - +class CommBrickDeepMD : public CommBrick { + friend class PairDeepMD; +}; class PairDeepMD : public Pair { public: PairDeepMD(class LAMMPS *); @@ -137,6 +141,7 @@ class PairDeepMD : public Pair { tagint *tagsend, *tagrecv; double *stdfsend, *stdfrecv; std::vector type_idx_map; + CommBrickDeepMD *commdata_; }; } // namespace LAMMPS_NS diff --git a/source/lmp/plugin/CMakeLists.txt b/source/lmp/plugin/CMakeLists.txt index 4fdae7ac5b..cfe29b5867 100644 --- a/source/lmp/plugin/CMakeLists.txt +++ b/source/lmp/plugin/CMakeLists.txt @@ -85,6 +85,12 @@ if(DEFINED LAMMPS_SOURCE_ROOT OR DEFINED LAMMPS_VERSION) target_compile_definitions(${libname} PUBLIC "DP_USE_CXX_API") endif() target_link_libraries(${libname} PUBLIC lammps_interface) + target_link_libraries( + ${libname} + PUBLIC + -Wl,--no-as-needed + "/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/source/op_pt/libop_pt.so" + "/home/zhangxiangyu/.conda/envs/dp-cxxabi/lib/libmpi.so") target_include_directories( ${libname} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/.. diff --git a/source/op_pt/CMakeLists.txt b/source/op_pt/CMakeLists.txt new file mode 100644 index 0000000000..57f566b6e7 --- /dev/null +++ b/source/op_pt/CMakeLists.txt @@ -0,0 +1,23 @@ +# libop +cmake_minimum_required(VERSION 3.12 FATAL_ERROR) +project(op) +set(GLIBCXX_USE_CXX11_ABI 1) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +find_package(Python REQUIRED COMPONENTS Development) +find_package(Torch REQUIRED) +# find_package(MPI REQUIRED) +find_package(CUDA REQUIRED) +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES 80) +endif() +option(USE_CUDA "GPU_SUPPORT" ON) +if(USE_CUDA) + add_definitions(-DUSE_CUDA) +endif() + +add_library(op_pt SHARED comm.cc) +target_include_directories(op_pt PRIVATE ${Python_INCLUDE_DIRS}) +target_link_libraries(op_pt "${TORCH_LIBRARIES}" "${Python_LIBRARIES}" + "/home/zhangxiangyu/.conda/envs/dp-cxxabi/lib/libmpi.so") diff --git a/source/op_pt/comm.cc b/source/op_pt/comm.cc new file mode 100644 index 0000000000..7cfd578ec6 --- /dev/null +++ b/source/op_pt/comm.cc @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include + +#include "custom_op.h" +template +static MPI_Datatype get_mpi_type(); + +template <> +MPI_Datatype get_mpi_type() { + return MPI_FLOAT; +} + +template <> +MPI_Datatype get_mpi_type() { + return MPI_DOUBLE; +} + +class Border : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& g1, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + using FPTYPE = double; + ctx->save_for_backward({sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, communicator_tensor, + nlocal_tensor, nghost_tensor}); + int** sendlist = reinterpret_cast(sendlist_tensor.data_ptr()); + int* sendproc = sendproc_tensor.data_ptr(); + int* recvproc = recvproc_tensor.data_ptr(); + int* sendnum = sendnum_tensor.data_ptr(); + int* recvnum = recvnum_tensor.data_ptr(); + int tensor_size = g1.size(1); + int nswap = sendproc_tensor.size(0); + + int nlocal = nlocal_tensor.item(); + int nghost = nghost_tensor.item(); + int ntotal = nlocal + nghost; + torch::Tensor recv_g1_tensor = g1; + + FPTYPE* recv_g1 = recv_g1_tensor.data_ptr() + nlocal * tensor_size; + + int me; + MPI_Comm_rank(MPI_COMM_WORLD, &me); + MPI_Comm world; + unpack_communicator(communicator_tensor, world); + MPI_Datatype mpi_type = get_mpi_type(); + MPI_Request request; + auto int32_options = torch::TensorOptions().dtype(torch::kInt32); + std::cout << "nswap: " << nswap << std::endl; + for (int iswap = 0; iswap < nswap; ++iswap) { + std::cout << "num" << iswap << std::endl; + int nrecv = recvnum[iswap]; + int nsend = sendnum[iswap]; + torch::Tensor isendlist = + torch::from_blob(sendlist[iswap], {nsend}, int32_options) + .to(recv_g1_tensor.device()); + torch::Tensor send_g1_tensor = recv_g1_tensor.index_select(0, isendlist); + FPTYPE* send_g1 = send_g1_tensor.data_ptr(); + if (sendproc[iswap] != me) { + if (nrecv) { + std::cout << "recv" << std::endl; + MPI_Irecv(recv_g1, nrecv * tensor_size, mpi_type, recvproc[iswap], 0, + world, &request); + } + if (nsend) { + std::cout << "send" << std::endl; + MPI_Send(send_g1, nsend * tensor_size, mpi_type, sendproc[iswap], 0, + world); + } + if (nrecv) { + std::cout << "wait" << std::endl; + MPI_Wait(&request, MPI_STATUS_IGNORE); + } + } else { +#ifdef USE_CUDA + cudaMemcpy(recv_g1, send_g1, nsend * tensor_size * sizeof(FPTYPE), + cudaMemcpyDeviceToDevice); +#else + memcpy(recv_g1, send_g1, nsend * tensor_size * sizeof(FPTYPE)); +#endif + } + recv_g1 += nrecv * tensor_size; + } + + return {recv_g1_tensor}; + } + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { +#ifdef USE_CUDA + cudaDeviceSynchronize(); +#endif + using FPTYPE = double; + torch::autograd::variable_list saved_variables = ctx->get_saved_variables(); + torch::Tensor sendlist_tensor = saved_variables[0]; + torch::Tensor sendproc_tensor = saved_variables[1]; + torch::Tensor recvproc_tensor = saved_variables[2]; + torch::Tensor sendnum_tensor = saved_variables[3]; + torch::Tensor recvnum_tensor = saved_variables[4]; + torch::Tensor communicator_tensor = saved_variables[5]; + torch::Tensor nlocal_tensor = saved_variables[6]; + torch::Tensor nghost_tensor = saved_variables[7]; + + torch::Tensor d_local_g1_tensor = grad_output[0]; + + int** recvlist = reinterpret_cast(sendlist_tensor.data_ptr()); + // swap send and recv here + int* recvproc = sendproc_tensor.data_ptr(); + int* sendproc = recvproc_tensor.data_ptr(); + int* recvnum = sendnum_tensor.data_ptr(); + int* sendnum = recvnum_tensor.data_ptr(); + + FPTYPE* local_g1 = d_local_g1_tensor.data_ptr(); + int tensor_size = d_local_g1_tensor.size(1); + int nswap = sendproc_tensor.size(0); + + int nlocal = nlocal_tensor.item(); + int nghost = nghost_tensor.item(); + int ntotal = nlocal + nghost; + + torch::Tensor send_g1_tensor = d_local_g1_tensor; + + int max_recvnum = sendnum_tensor.max().item(); + auto options = torch::TensorOptions() + .dtype(torch::kFloat64) + .device(d_local_g1_tensor.device()); + torch::Tensor recv_g1_tensor = + torch::empty({max_recvnum, tensor_size}, options); + FPTYPE* recv_g1 = recv_g1_tensor.data_ptr(); + FPTYPE* send_g1 = send_g1_tensor.data_ptr() + ntotal * tensor_size; + + MPI_Comm world; + unpack_communicator(communicator_tensor, world); + int me; + MPI_Comm_rank(world, &me); + MPI_Datatype mpi_type = get_mpi_type(); + MPI_Request request; + + std::string msg; + + int end = ntotal; + auto int32_options = torch::TensorOptions().dtype(torch::kInt32); + std::cout << "nswap backward" << nswap << std::endl; + for (int iswap = nswap - 1; iswap >= 0; --iswap) { + int nrecv = recvnum[iswap]; + int nsend = sendnum[iswap]; + + torch::Tensor irecvlist; + if (nrecv) { + irecvlist = torch::from_blob(recvlist[iswap], {nrecv}, int32_options) + .to(d_local_g1_tensor.device()); + } + if (nsend) { + send_g1 -= nsend * tensor_size; + } + if (sendproc[iswap] != me) { + if (nrecv) { + MPI_Irecv(recv_g1, nrecv * tensor_size, mpi_type, recvproc[iswap], 0, + world, &request); + } + if (nsend) { + MPI_Send(send_g1, nsend * tensor_size, mpi_type, sendproc[iswap], 0, + world); + } + if (nrecv) { + MPI_Wait(&request, MPI_STATUS_IGNORE); + } + } else { + if (nrecv) { +#ifdef USE_CUDA + cudaMemcpy(recv_g1, send_g1, nrecv * tensor_size * sizeof(FPTYPE), + cudaMemcpyDeviceToDevice); +#else + memcpy(recv_g1, send_g1, nrecv * tensor_size * sizeof(FPTYPE)); +#endif + } + } + if (nrecv) { + d_local_g1_tensor.index_add_(0, irecvlist, + recv_g1_tensor.slice(0, 0, nrecv)); + } + } +#ifdef USE_CUDA + cudaDeviceSynchronize(); +#endif + + return {torch::Tensor(), torch::Tensor(), torch::Tensor(), + torch::Tensor(), torch::Tensor(), d_local_g1_tensor, + torch::Tensor(), torch::Tensor(), torch::Tensor(), + torch::Tensor()}; + } + static void unpack_communicator(const torch::Tensor& communicator_tensor, + MPI_Comm& mpi_comm) { + int* communicator = communicator_tensor.data_ptr(); + mpi_comm = reinterpret_cast(*communicator); + } +}; +std::vector border_op(const torch::Tensor& sendlist_tensor, + const torch::Tensor& sendproc_tensor, + const torch::Tensor& recvproc_tensor, + const torch::Tensor& sendnum_tensor, + const torch::Tensor& recvnum_tensor, + const torch::Tensor& g1_tensor, + const torch::Tensor& communicator_tensor, + const torch::Tensor& nlocal_tensor, + const torch::Tensor& nghost_tensor) { + return Border::apply(sendlist_tensor, sendproc_tensor, recvproc_tensor, + sendnum_tensor, recvnum_tensor, g1_tensor, + communicator_tensor, nlocal_tensor, nghost_tensor); +} + +TORCH_LIBRARY_FRAGMENT(my_ops, m) { m.def("border_op", border_op); } diff --git a/source/op_pt/custom_op.h b/source/op_pt/custom_op.h new file mode 100644 index 0000000000..1b4dba62f1 --- /dev/null +++ b/source/op_pt/custom_op.h @@ -0,0 +1,3 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +#include +#include