lmp: support distributed DPA2 model inference #3440

Closed · wants to merge 10 commits
3 changes: 2 additions & 1 deletion deepmd/pt/entrypoints/main.py
@@ -46,6 +46,7 @@
)
from deepmd.pt.utils.env import (
DEVICE,
load_op,
)
from deepmd.pt.utils.finetune import (
change_finetune_model_params,
@@ -299,7 +300,7 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
set_log_handles(FLAGS.log_level, FLAGS.log_path, mpi_log=None)
log.debug("Log handles were successfully set")
log.info("DeepMD version: %s", __version__)

load_op()
if FLAGS.command == "train":
train(FLAGS)
elif FLAGS.command == "freeze":
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
@@ -105,6 +105,7 @@ def forward_common_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.

@@ -153,6 +154,7 @@ def forward_common_atomic(
mapping=mapping,
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
)

# nf x nloc
6 changes: 2 additions & 4 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
@@ -134,6 +134,7 @@ def forward_atomic(
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Return atomic prediction.

@@ -163,10 +164,7 @@
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
extended_coord, extended_atype, nlist, mapping=mapping, comm_dict=comm_dict
)
assert descriptor is not None
# energy, force
16 changes: 10 additions & 6 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
@@ -395,6 +396,7 @@ def forward(
extended_atype: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

@@ -450,11 +452,12 @@
# linear to change shape
g1 = self.g1_shape_tranform(g1)
# mapping g1
assert mapping is not None
mapping_ext = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
)
g1_ext = torch.gather(g1, 1, mapping_ext)
if comm_dict is None:
assert mapping is not None
# mapping_ext = (
# mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, g1.shape[-1])
# )
# g1_ext = torch.gather(g1, 1, mapping_ext)
# repformer
g1, g2, h2, rot_mat, sw = self.repformers(
nlist_dict[
@@ -464,8 +467,9 @@
],
extended_coord,
extended_atype,
g1_ext,
g1,
mapping,
comm_dict,
)
if self.concat_output_tebd:
g1 = torch.cat([g1, g1_inp], dim=-1)
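Note on the change above: the descriptor now has two paths for building the repformer input. Without comm_dict, the extended (nall-sized) feature tensor is still produced by a torch.gather over mapping; with comm_dict, only the nloc local features are passed on, and the ghost-atom rows are filled later by the communication op inside the repformer block. A minimal sketch of that control flow follows; the helper name extend_g1 and the shapes are assumptions for illustration, not the PR's code.

from typing import Dict, Optional

import torch


def extend_g1(
    g1: torch.Tensor,  # nf x nloc x ng1 local single-atom features
    mapping: Optional[torch.Tensor],  # nf x nall, extended-atom -> local-atom index
    comm_dict: Optional[Dict[str, torch.Tensor]],
) -> torch.Tensor:
    if comm_dict is None:
        # serial path: build the nf x nall x ng1 tensor by gathering local rows
        assert mapping is not None
        nf, nloc, ng1 = g1.shape
        nall = mapping.shape[1]
        index = mapping.view(nf, nall, 1).expand(-1, -1, ng1)
        return torch.gather(g1, 1, index)
    # distributed path: keep the nloc-sized tensor; the ghost rows are filled
    # later by the border_op halo exchange inside the repformer layers
    return g1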
50 changes: 44 additions & 6 deletions deepmd/pt/model/descriptor/repformers.py
@@ -227,16 +227,18 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@torch.jit.script_method
def forward(
self,
nlist: torch.Tensor,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_atype_embd: Optional[torch.Tensor] = None,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
assert mapping is not None
assert extended_atype_embd is not None
if comm_dict is None:
assert extended_atype_embd is not None
nframes, nloc, nnei = nlist.shape
nall = extended_coord.view(nframes, -1).shape[1] // 3
atype = extended_atype[:, :nloc]
@@ -257,8 +259,12 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# [nframes, nloc, tebd_dim]
atype_embd = extended_atype_embd[:, :nloc, :]
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
# atype_embd = extended_atype_embd[:, :nloc, :]
atype_embd = extended_atype_embd
if atype_embd is not None:
assert list(atype_embd.shape) == [nframes, nloc, self.g1_dim]
else:
raise NotImplementedError

g1 = self.act(atype_embd)
# nb x nloc x nnei x 1, nb x nloc x nnei x 3
@@ -275,11 +281,43 @@
# if the a neighbor is real or not is indicated by nlist_mask
nlist[nlist == -1] = 0
# nb x nall x ng1
mapping = mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
if comm_dict is None:
assert mapping is not None
mapping = (
mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.g1_dim)
)
for idx, ll in enumerate(self.layers):
# g1: nb x nloc x ng1
# g1_ext: nb x nall x ng1
g1_ext = torch.gather(g1, 1, mapping)
if comm_dict is None:
assert mapping is not None
g1_ext = torch.gather(g1, 1, mapping)
else:
# padding = torch.zeros(nall-nloc, g1.size(2),device=mydev)
# g1 = torch.cat((g1.squeeze(0), padding), dim=0)
n_padding = nall - nloc
g1 = torch.nn.functional.pad(
g1.squeeze(0), (0, 0, 0, n_padding), value=0.0
)
assert "send_list" in comm_dict
assert "send_proc" in comm_dict
assert "recv_proc" in comm_dict
assert "send_num" in comm_dict
assert "recv_num" in comm_dict
assert "communicator" in comm_dict
ret = env.op_module.border_op(
comm_dict["send_list"],
comm_dict["send_proc"],
comm_dict["recv_proc"],
comm_dict["send_num"],
comm_dict["recv_num"],
g1,
comm_dict["communicator"],
torch.tensor(nloc),
torch.tensor(nall - nloc),
)
g1_ext = ret[0].unsqueeze(0)

g1, g2, h2 = ll.forward(
g1_ext,
g2,
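In the distributed branch added above, g1 is padded from nloc to nall rows and the custom border_op fills the ghost rows with features received from neighboring MPI ranks before each repformer layer runs. The routing data comes from comm_dict, whose required keys are exactly the ones asserted in the diff. A hedged sketch of its layout follows; the real tensors and the communicator handle are built on the C++ side from LAMMPS swap data, so the dtypes and empty values here are placeholders only.

import torch

# Placeholder comm_dict with the keys asserted in DescrptBlockRepformers.forward;
# contents are illustrative, not the values produced at runtime.
comm_dict = {
    "send_list": torch.empty(0, dtype=torch.int32),  # local-atom indices to send, flattened per swap
    "send_proc": torch.empty(0, dtype=torch.int32),  # destination rank of each swap
    "recv_proc": torch.empty(0, dtype=torch.int32),  # source rank of each swap
    "send_num": torch.empty(0, dtype=torch.int32),   # number of atoms sent per swap
    "recv_num": torch.empty(0, dtype=torch.int32),   # number of atoms received per swap
    "communicator": torch.empty(1),                  # opaque handle wrapping the MPI communicator
}

border_op returns the nall-sized feature tensor with the ghost rows populated, and ret[0].unsqueeze(0) restores the frame dimension before the layer call.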
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/se_a.py
@@ -191,6 +191,7 @@ def forward(
atype_ext: torch.Tensor,
nlist: torch.Tensor,
mapping: Optional[torch.Tensor] = None,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Compute the descriptor.

3 changes: 3 additions & 0 deletions deepmd/pt/model/model/ener_model.py
@@ -6,6 +6,7 @@

import torch


from .dp_model import (
DPModel,
)
@@ -69,6 +70,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
model_ret = self.forward_common_lower(
extended_coord,
@@ -78,6 +80,7 @@
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
)
if self.get_fitting_net() is not None:
model_predict = {}
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/make_model.py
@@ -181,6 +181,7 @@ def forward_common_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
comm_dict: Optional[Dict[str, torch.Tensor]] = None,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
@@ -224,6 +225,7 @@
mapping=mapping,
fparam=fp,
aparam=ap,
comm_dict=comm_dict,
)
model_predict = fit_output_to_model_output(
atomic_ret,
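These two files thread the new keyword from the public lower-level interface down to the atomic model, so the full path is forward_lower -> forward_common_lower -> forward_common_atomic -> forward_atomic -> descriptor. A hedged usage sketch follows; the model object and the input tensors are assumed to be prepared elsewhere, the keyword names come from the diff, and mapping is omitted because the distributed path does not rely on it.

# Illustrative call, not the PR's test code; extended_coord, extended_atype,
# nlist, and comm_dict are assumed to be built by the caller (e.g. the LAMMPS
# interface via DeepPotPT).
model_ret = model.forward_lower(
    extended_coord,       # nf x nall x 3 coordinates of local + ghost atoms
    extended_atype,       # nf x nall atom types
    nlist,                # nf x nloc x nnei neighbor list
    do_atomic_virial=False,
    comm_dict=comm_dict,  # send/recv metadata and communicator handle (see sketch above)
)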
12 changes: 12 additions & 0 deletions deepmd/pt/utils/env.py
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
from typing import (
Any,
)

import numpy as np
import torch
@@ -80,3 +83,12 @@
"ENERGY_BIAS_TRAINABLE",
"LOCAL_RANK",
]


def load_op():
torch.ops.load_library(
"/mnt/user/zhangxiangyu/workspace/dpkit/deepmd-kit/source/op_pt/libop_pt.so"
)


op_module: Any = torch.ops.my_ops

Code scanning / CodeQL note: the global variable 'op_module' is not used.
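The absolute path in load_op above is clearly a work-in-progress placeholder. A purely illustrative alternative that resolves the library location from an environment variable is sketched below; the variable name DP_OP_PT_LIB and the bare-filename fallback are assumptions, not part of this PR.

import os
from typing import Any

import torch


def load_op() -> None:
    # Assumed lookup: take the op library path from DP_OP_PT_LIB if set,
    # otherwise let the dynamic loader resolve a plain library name.
    lib_path = os.environ.get("DP_OP_PT_LIB", "libop_pt.so")
    torch.ops.load_library(lib_path)


# lazily-resolved op namespace; repformers.py calls env.op_module.border_op
op_module: Any = torch.ops.my_ops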
13 changes: 13 additions & 0 deletions source/api_c/include/c_api.h
@@ -25,6 +25,19 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* numneigh_,
int** firstneigh_);

extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
int world);

/**
* @brief Delete a neighbor list.
*
46 changes: 44 additions & 2 deletions source/api_c/include/deepmd.hpp
@@ -502,6 +502,20 @@ inline double *_DP_Get_Energy_Pointer(double &vec, const int nframes) {

namespace deepmd {
namespace hpp {
// struct CommData {
// int nswap;
// int* sendnum;
// int* recvnum;
// int* firstrecv;
// int** sendlist;
// int* sendproc;
// int* recvproc;
// long int* world;

// CommData() : nswap(0), sendnum(nullptr), recvnum(nullptr),
// firstrecv(nullptr), sendlist(nullptr),
// sendproc(nullptr), recvproc(nullptr),world(nullptr) {}
// };
/**
* @brief Neighbor list.
**/
@@ -522,6 +536,36 @@ struct InputNlist {
nl(DP_NewNlist(inum_, ilist_, numneigh_, firstneigh_)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
};
InputNlist(int inum_,
int *ilist_,
int *numneigh_,
int **firstneigh_,
int nswap,
int *sendnum,
int *recvnum,
int *firstrecv,
int **sendlist,
int *sendproc,
int *recvproc,
int world)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_),
nl(DP_NewNlist_comm(inum_,
ilist_,
numneigh_,
firstneigh_,
nswap,
sendnum,
recvnum,
firstrecv,
sendlist,
sendproc,
recvproc,
world)){
// DP_CHECK_OK(DP_NlistCheckOK, nl);
};
~InputNlist() { DP_DeleteNlist(nl); };
/// @brief C API neighbor list.
DP_Nlist *nl;
@@ -798,7 +842,6 @@ class DeepPot {
aparam);
const VALUETYPE *fparam__ = !fparam_.empty() ? &fparam_[0] : nullptr;
const VALUETYPE *aparam__ = !aparam_.empty() ? &aparam_[0] : nullptr;

_DP_DeepPotComputeNList<VALUETYPE>(
dp, nframes, natoms, coord_, atype_, box_, nghost, lmp_list.nl, ago,
fparam__, aparam__, ener_, force_, virial_, nullptr, nullptr);
@@ -2117,6 +2160,5 @@ void select_map(std::vector<VT> &out,
out.resize(static_cast<size_t>(nall2) * stride);
DP_SelectMapInt(&in[0], &fwd_map[0], stride, nall1, nall2, &out[0]);
};

} // namespace hpp
} // namespace deepmd
19 changes: 18 additions & 1 deletion source/api_c/src/c_api.cc
@@ -24,6 +24,24 @@ DP_Nlist* DP_NewNlist(int inum_,
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_);
DP_Nlist* new_nl = new DP_Nlist(nl); return new_nl;)
}
DP_Nlist* DP_NewNlist_comm(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_,
int nswap,
int* sendnum,
int* recvnum,
int* firstrecv,
int** sendlist,
int* sendproc,
int* recvproc,
int world) {
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, nswap, sendnum,
recvnum, firstrecv, sendlist, sendproc, recvproc,
world);
DP_Nlist* new_nl = new DP_Nlist(nl);
return new_nl;
}

void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

@@ -268,7 +286,6 @@ inline void DP_DeepPotComputeNList_variant(DP_DeepPot* dp,
}
std::vector<double> e;
std::vector<VALUETYPE> f, v, ae, av;

DP_REQUIRES_OK(dp, dp->dp.compute(e, f, v, ae, av, coord_, atype_, cell_,
nghost, nlist->nl, ago, fparam_, aparam_));
// copy from C++ vectors to C arrays, if not NULL pointer
1 change: 1 addition & 0 deletions source/api_cc/include/DeepPotPT.h
@@ -327,6 +327,7 @@ class DeepPotPT : public DeepPotBase {
int gpu_id;
bool gpu_enabled;
at::Tensor firstneigh_tensor;
torch::Dict<std::string, torch::Tensor> comm_dict;
};

} // namespace deepmd
1 change: 1 addition & 0 deletions source/api_cc/include/common.h
@@ -13,6 +13,7 @@
namespace deepmd {

typedef double ENERGYTYPE;
// TODO: currently we only implement TF&PT; reserve for future use
enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };

struct NeighborListData {