Skip to content

Commit

Permalink
support DPA2 inference for 2024Q1 (#3756)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CaRoLZhangxy and pre-commit-ci[bot] authored May 8, 2024
1 parent 23f67a1 commit eed7c8a
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
2 changes: 2 additions & 0 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
extern DP_Nlist* DP_NewNlist_mapping(
int inum_, int* ilist_, int* numneigh_, int** firstneigh_, int* mapping);

/**
* @brief Delete a neighbor list.
Expand Down
11 changes: 11 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,18 @@ struct InputNlist {
nl(DP_NewNlist(inum_, ilist_, numneigh_, firstneigh_)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
};
InputNlist(
int inum_, int *ilist_, int *numneigh_, int **firstneigh_, int *mapping)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_),
nl(DP_NewNlist_mapping(
inum_, ilist_, numneigh_, firstneigh_, mapping)) {
DP_CHECK_OK(DP_NlistCheckOK, nl);
};
~InputNlist() { DP_DeleteNlist(nl); };

/// @brief C API neighbor list.
DP_Nlist *nl;
/// @brief Number of core region atoms
Expand Down
7 changes: 6 additions & 1 deletion source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ DP_Nlist* DP_NewNlist(int inum_,
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_);
DP_Nlist* new_nl = new DP_Nlist(nl); return new_nl;)
}

DP_Nlist* DP_NewNlist_mapping(
int inum_, int* ilist_, int* numneigh_, int** firstneigh_, int* mapping) {
deepmd::InputNlist nl(inum_, ilist_, numneigh_, firstneigh_, mapping);
DP_Nlist* new_nl = new DP_Nlist(nl);
return new_nl;
}
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

DP_DeepPot::DP_DeepPot() {}
Expand Down
10 changes: 8 additions & 2 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,13 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
bool do_atom_virial_tensor = true;
c10::optional<torch::Tensor> optional_tensor;
c10::optional<torch::Tensor> mapping_tensor;
if (lmp_list.mapping != nullptr) {
mapping_tensor =
torch::from_blob(lmp_list.mapping, {1, nall_real}, int32_options)
.to(torch::kInt64)
.to(device);
}
c10::optional<torch::Tensor> fparam_tensor;
if (!fparam.empty()) {
fparam_tensor =
Expand All @@ -174,7 +180,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
c10::Dict<c10::IValue, c10::IValue> outputs =
module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
firstneigh_tensor, optional_tensor, fparam_tensor,
firstneigh_tensor, mapping_tensor, fparam_tensor,
aparam_tensor, do_atom_virial_tensor)
.toGenericDict();
c10::IValue energy_ = outputs.at("energy");
Expand Down
14 changes: 12 additions & 2 deletions source/lib/include/neighbor_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,22 @@ struct InputNlist {
int* numneigh;
/// Array stores the core region atom's neighbor index
int** firstneigh;
InputNlist() : inum(0), ilist(NULL), numneigh(NULL), firstneigh(NULL){};
int* mapping;
InputNlist()
: inum(0), ilist(NULL), numneigh(NULL), firstneigh(NULL), mapping(NULL){};
InputNlist(int inum_, int* ilist_, int* numneigh_, int** firstneigh_)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_){};
firstneigh(firstneigh_),
mapping(NULL){};
InputNlist(
int inum_, int* ilist_, int* numneigh_, int** firstneigh_, int* mapping)
: inum(inum_),
ilist(ilist_),
numneigh(numneigh_),
firstneigh(firstneigh_),
mapping(mapping){};
~InputNlist(){};
};

Expand Down
11 changes: 9 additions & 2 deletions source/lmp/pair_deepmd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ void PairDeepMD::compute(int eflag, int vflag) {
double **x = atom->x;
double **f = atom->f;
int *type = atom->type;
int *tag_array = atom->tag;
int nlocal = atom->nlocal;
int nghost = 0;
if (do_ghost) {
Expand All @@ -484,7 +485,11 @@ void PairDeepMD::compute(int eflag, int vflag) {
}
}
}

// make mapping array
int *mapping = new int[nall];
for (int i = 0; i < nall; ++i) {
mapping[i] = atom->map(tag_array[i]);
}
vector<int> dtype(nall);
for (int ii = 0; ii < nall; ++ii) {
dtype[ii] = type_idx_map[type[ii] - 1];
Expand Down Expand Up @@ -551,7 +556,7 @@ void PairDeepMD::compute(int eflag, int vflag) {
(numb_models > 1 && (out_freq > 0 && update->ntimestep % out_freq == 0));
if (do_ghost) {
deepmd_compat::InputNlist lmp_list(list->inum, list->ilist, list->numneigh,
list->firstneigh);
list->firstneigh, mapping);
deepmd_compat::InputNlist extend_lmp_list;
if (atom->sp_flag) {
extend(extend_inum, extend_ilist, extend_numneigh, extend_neigh,
Expand Down Expand Up @@ -1280,6 +1285,8 @@ void PairDeepMD::coeff(int narg, char **arg) {
void PairDeepMD::init_style() {
#if LAMMPS_VERSION_NUMBER >= 20220324
neighbor->add_request(this, NeighConst::REQ_FULL);
atom->map_user = 2;
atom->map_init(1);
#else
int irequest = neighbor->request(this, instance_me);
neighbor->requests[irequest]->half = 0;
Expand Down

0 comments on commit eed7c8a

Please sign in to comment.